[Source code analysis] PyTorch Distributed (11) -- DistributedDataParallel: Building the Reducer and the Join Operation


  • 0x00 Abstract
  • 0x01 Introduction
  • 1.1 Invocation
  • 1.2 Parameters
  • 0x02 Reducer initialization
  • 2.1 Constructor
  • 2.2 Initializing buckets
  • 2.3 Initializing views
  • 2.3.1 BucketReplica member variables
  • 2.3.2 Invocation
  • 2.4 Initializing the locally-used map
  • 0x03 Static graph
  • 3.1 Motivation
  • 3.2 Usage
  • 3.3 Reducer
  • 0x04 Rebuilding buckets
  • 4.1 Why rebuild
  • 4.2 Preparing to rebuild
  • 4.3 Rebuilding
  • 4.4 When rebuilding is set
  • 4.5 Direct invocation
  • 0x05 Join
  • 5.1 Motivation
  • 5.2 Usage
  • 5.2.1 DistributedDataParallel
  • 5.2.2 ZeroRedundancyOptimizer
  • 5.3 Principle
  • 5.3.1 Joinable
  • 5.3.2 JoinHook
  • 5.3.2.1 ZeroRedundancyOptimizer
  • 5.3.3 Join
  • 5.4 Example
  • 0xFF References

0x00 Abstract

The previous post analyzed the various member variables of the Reducer, so this post moves on to the dynamic logic. The goal is to tie the previous posts together and lay the groundwork for the upcoming analysis of forward and backward propagation.

The other posts in this series are:

Automatic Differentiation as a Deep Learning Tool (1)

Automatic Differentiation as a Deep Learning Tool (2)

[Source code analysis] Automatic Differentiation as a Deep Learning Tool (3) -- Walking Through Examples

[Source code analysis] How PyTorch Implements Forward Propagation (1) -- Basic Classes (Part 1)

[Source code analysis] How PyTorch Implements Forward Propagation (2) -- Basic Classes (Part 2)

[Source code analysis] How PyTorch Implements Forward Propagation (3) -- Concrete Implementation

[Source code analysis] How PyTorch Implements Backward Propagation (1) -- Invoking the Engine

[Source code analysis] How PyTorch Implements Backward Propagation (2) -- Static Structure of the Engine

[Source code analysis] How PyTorch Implements Backward Propagation (3) -- Dynamic Logic of the Engine

[Source code analysis] How PyTorch Implements Backward Propagation (4) -- The Concrete Algorithm

[Source code analysis] PyTorch Distributed (1) -- History and Overview

[Source code analysis] PyTorch Distributed (2) -- DataParallel (Part 1)

[Source code analysis] PyTorch Distributed (3) -- DataParallel (Part 2)

[Source code analysis] PyTorch Distributed (4) -- Basic Concepts of Distributed Applications

[Source code analysis] PyTorch Distributed (5) -- DistributedDataParallel Overview & How to Use It

[Source code analysis] PyTorch Distributed (6) -- DistributedDataParallel: Initialization & Store

[Source code analysis] PyTorch Distributed (7) -- DistributedDataParallel: Process Groups

[Source code analysis] PyTorch Distributed (8) -- DistributedDataParallel: The Paper

[Source code analysis] PyTorch Distributed (9) -- DistributedDataParallel: Initialization

[Source code analysis] PyTorch Distributed (10) -- DistributedDataParallel: Static Architecture of the Reducer

0x01 Introduction

For a better analysis, let's first look at how the Reducer is invoked.

1.1 Invocation

The Reducer is created inside _ddp_init_helper, as follows:

        # Note: reverse list of buckets because we want to approximate the
        # order in which their gradients are produced, and assume they
        # are used in the forward pass in the order they are defined.
        self.reducer = dist.Reducer(
            parameters, # parameters[0] is the list of tensors
            list(reversed(bucket_indices)), # bucket information
            self.process_group,
            expect_sparse_gradient,
            self.bucket_bytes_cap,
            self.find_unused_parameters,
            self.gradient_as_bucket_view,
            param_to_name_mapping,
        )

1.2 Parameters

An example of the parameters argument is shown below. parameters[0] holds the parameters of the model on rank 0; only element [0] is meaningful, and it itself contains 20 elements:

parameters = {list: 1} 
0 = {list: 4}           
 0 = {Parameter: 10} Parameter containing:\ntensor([[-4.0381e-02,  3.8828e-02, 1  )   
 1 = {Parameter: 10} Parameter containing:\ntensor([-0.0438, -0.2033,  0.2771,  0.0721,  ) 
 2 = {Parameter: 5} Parameter containing:\ntensor([[-0.0094, -0.1319,  0.0713,  0.3155,  )
 3 = {Parameter: 5} Parameter containing:\ntensor([-0.0008,  0.0582, -0.1245, -0.2538, )
 ...
 20 = {Parameter: 5} Parameter containing:\ntensor([-0.0008,  0.0582, -0.1245, -0.2538, )                                                   
 __len__ = {int} 20
__len__ = {int} 1
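As a rough, hypothetical sketch (the real logic lives in DDP's Python initialization code, and `model` here is just a placeholder), this nested structure is assembled along these lines:

# Sketch only: one "replica" whose element [0] is the list of parameters
# that require gradients; only parameters[0] is meaningful.
parameters = [[p for p in model.parameters() if p.requires_grad]]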

An example of bucket_indices is as follows.

Regarding tensor indices: every tensor is given an index, increasing from 0 up to tensors.size() - 1. If the model's parameters consist of 20 tensors in total, the tensor indices run from 0 to 19 and are split into 6 buckets; across these 6 buckets, each tensor index is unique and appears exactly once.

+-----------------------------------------------------------------------+
|                                                                       |
|  <tensor index 0, tensor index 1, tensor index 2, tensor index 3>     |
|                                                                       |
|                                                                       |
|  <tensor index 4, tensor index 5, tensor 6>                           |
|                                                                       |
|                                                                       |
|  ......                                                               |
|                                                                       |
|                                                                       |
|  <tensor index 16, tensor index 17, tensor index 18, tensor index 19> |
|                                                                       |
+-----------------------------------------------------------------------+
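As a small illustration (the exact grouping below is hypothetical), bucket_indices is simply a list of index lists in which every tensor index appears exactly once:

# Hypothetical bucket_indices for 20 parameters split into 6 buckets.
bucket_indices = [
    [0, 1, 2, 3],
    [4, 5, 6],
    [7, 8],
    [9, 10, 11],
    [12, 13, 14, 15],
    [16, 17, 18, 19],
]
# Every tensor index from 0 to 19 appears exactly once across all buckets.
assert sorted(i for b in bucket_indices for i in b) == list(range(20))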

Next, let's look at how the Reducer is initialized.

0x02 Reducer initialization

The code lives in torch/lib/c10d/reducer.h and torch/lib/c10d/reducer.cpp.

2.1 Constructor

The logic is roughly as follows:

  • Check whether this module is a multi-device module: iterate over the tensors, get each tensor's device and insert it into a set; if the set ends up holding more than one device, the module is multi-device.
  • If expect_sparse_gradients is not specified, initialize expect_sparse_gradients_ to all false.
  • Call initialize_buckets to initialize the buckets and, as far as possible, assign parameters to buckets in reverse order, so that bucket-by-bucket communication can be more efficient. The buckets may also be re-initialized later at runtime.
  • Attach a grad_accumulator hook to every parameter; these hooks are responsible for gradient synchronization during backward (see the Python-side sketch after this list):

    Since these variables are leaf tensors of the autograd graph, their grad_fn is set to the gradient accumulation function.
    The Reducer stores pointers to these functions, so it can check whether they were used in an autograd pass; if not, it knows their grad tensors can be marked as ready for reduction.
    Iterate over the tensors and create a VariableIndex for each of them.
    Get grad_accumulator_ from Variable::AutogradMeta, i.e. the gradient accumulator that accumulates the gradients of a leaf Variable.
    Add the Reducer's autograd_hook as a post hook on each grad_accumulator_, with the VariableIndex as the hook's argument. This hook hangs off the autograd graph and handles gradient synchronization during backward; once grad_accumulator has executed, autograd_hook runs.

  • gradAccToVariableMap_ stores the grad_accumulator-to-index mapping (i.e. function pointer to parameter tensor), which makes it easy to look for unused parameters later when traversing the autograd graph.
  • Initialize backward_stats_.
  • Call initialize_local_used_map to initialize the various "unused" maps.
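Before diving into the C++ constructor, here is a small Python-side sketch of the same hook idea. This is not the Reducer's code; it is the well-known trick of grabbing a leaf tensor's AccumulateGrad node and registering a hook on it (exact behavior may vary slightly across PyTorch versions):

import torch

param = torch.nn.Parameter(torch.randn(3))

# For a leaf tensor, expand_as produces a non-leaf alias whose grad_fn chains
# back to the parameter's AccumulateGrad node.
tmp = param.expand_as(param)
grad_acc = tmp.grad_fn.next_functions[0][0]   # the AccumulateGrad node

def autograd_hook(*unused):
    # In the Reducer, this is where the variable would be marked ready.
    print("gradient for this parameter has been accumulated")

grad_acc.register_hook(autograd_hook)   # fires after the gradient is accumulated
param.sum().backward()                  # triggers the hook

The C++ constructor itself follows.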
// The constructor takes a list of variables for every model replica.
// The bucket assignment for this reducer is specified as a list of
// buckets, each of which is specified as a list of indices into the
// variables list for **a single replica** (i.e. `variables[0]`).
Reducer::Reducer(
    std::vector<std::vector<at::Tensor>> replicas, // the tensors
    std::vector<std::vector<size_t>> bucket_indices, // bucket information
    c10::intrusive_ptr<c10d::ProcessGroup> process_group,
    std::vector<std::vector<bool>> expect_sparse_gradients,
    int64_t bucket_bytes_cap,
    bool find_unused_parameters,
    bool gradient_as_bucket_view,
    std::unordered_map<size_t, std::string> paramNames)
    : replicas_(std::move(replicas)),
      process_group_(std::move(process_group)),
      expect_sparse_gradients_(std::move(expect_sparse_gradients)),
      expect_autograd_hooks_(false),
      require_finalize_(false),
      next_bucket_(0),
      has_marked_unused_parameters_(false),
      find_unused_parameters_(find_unused_parameters),
      gradient_as_bucket_view_(gradient_as_bucket_view),
      local_used_maps_reduced_(false),
      num_iterations_(0),
      num_buckets_ready_(0),
      has_rebuilt_bucket_(false),
      bucket_bytes_cap_(bucket_bytes_cap),
      divFactor_(kUnsetDivFactor),
      static_graph_(false),
      comm_hook_(nullptr),
      thread_local_state_(at::ThreadLocalState()),
      ddp_debug_level_(parseDistDebugLevel()),
      param_names_(std::move(paramNames)) {

  // Check whether the module is multi_device_module
  {
    std::set<int> unique_devices;
    for (const auto& v : replicas_[0]) { // iterate over the tensors
      auto device_idx = int(v.device().index()); // get the tensor's device
      if (unique_devices.find(device_idx) == unique_devices.end()) {
        unique_devices.insert(device_idx); // insert the device into a set
        if (unique_devices.size() > 1) { // more than one device in the set means multi-device
          is_multi_device_module_ = true; 
          break;
        }
      }
    }
  }

  // If `expect_sparse_gradients` is not specified, initialize it such that
  // we do not expect sparse gradients for any parameter.
  if (expect_sparse_gradients_.empty()) {
    expect_sparse_gradients_ = std::vector<std::vector<bool>>(
        replicas_.size(), std::vector<bool>(replicas_[0].size(), false));
  }

  // Initialize variable bucketing.
  // This can be reinitialized later after capturing runtime information.
  {
    std::lock_guard<std::mutex> lock(mutex_);
    initialize_buckets(std::move(bucket_indices)); // initialize the buckets
  }

  // All variables are expected to have their `grad_fn` set to the gradient
  // accumulation function (since they are leafs in the autograd graph).
  // We store pointers to these functions such that we can check if they are
  // used in an autograd pass. If they are not, we know their grad tensors
  // can be marked as ready for reduction.
  {
    const auto replica_count = replicas_.size();
    grad_accumulators_.resize(replica_count);
    for (size_t replica_index = 0; replica_index < replica_count; // only replicas_[0] is meaningful
         replica_index++) {
      const auto variable_count = replicas_[replica_index].size(); // number of tensors
      grad_accumulators_[replica_index].resize(variable_count); // allocate space in grad_accumulators_
        
      for (size_t variable_index = 0; variable_index < variable_count;
           variable_index++) { // iterate over the tensors; variable_index is the tensor's index
        auto& variable = replicas_[replica_index][variable_index]; // get the tensor
        const auto index = VariableIndex(replica_index, variable_index); // build a VariableIndex for each tensor

        // The gradient accumulator function is lazily initialized once.
        // Therefore we can use its presence in the autograd graph as
        // evidence that the parameter has participated in an iteration.
        auto grad_accumulator =
            torch::autograd::impl::grad_accumulator(variable); // get Variable::AutogradMeta's grad_accumulator_, i.e. the accumulator for the leaf Variable's gradient

#ifndef _WIN32
        using torch::distributed::autograd::ThreadLocalDistAutogradContext;
#endif
        // Hook to execute after the gradient accumulator has executed.
        hooks_.emplace_back(
            // Add a hook to the accumulator; the hook hangs off the autograd graph and handles gradient synchronization during backward.
            // Once grad_accumulator has executed, autograd_hook will run.
            grad_accumulator->add_post_hook(
                torch::make_unique<torch::autograd::utils::LambdaPostHook>(
                    [=](const torch::autograd::variable_list& outputs,
                        const torch::autograd::variable_list& /* unused */) {
#ifndef _WIN32
                      this->rpc_context_.set(
                          ThreadLocalDistAutogradContext::getContextPtr());
#endif
                      this->autograd_hook(index); // run the Reducer's autograd_hook
                      return outputs;
                    })),
            grad_accumulator);

        // Map raw function pointer to replica index and parameter index.
        // This is used later on when the autograd graph is traversed
        // to check for parameters for which no gradient is computed, if
        // find_unused_parameters=True.
        // Note that the mapping of gradient accumulator to variable should be
        // one to one as we deduplicate shared parameters before constructing
        // Reducer.
          
        // gradAccToVariableMap_ stores the grad_accumulator-to-index mapping (function pointer to parameter tensor),
        // which makes it easy to look for unused parameters later when traversing the autograd graph.
        if (find_unused_parameters_) {
          gradAccToVariableMap_[grad_accumulator.get()] = index;
        }

        numGradHooksTriggeredMap_[index] = 0;

        // The gradient accumulator is stored as weak_ptr in the autograd
        // metadata of the variable, so we have to keep it alive here for
        // the raw pointer to be valid.
        TORCH_CHECK(
            grad_accumulators_[replica_index][variable_index] == nullptr,
            c10::str(
                "Reducer tried to register duplicate grad accumulator for replica ",
                replica_index,
                " variable ",
                variable_index));
        grad_accumulators_[replica_index][variable_index] =
            std::move(grad_accumulator);
      }
    }
  }

  // Initialize backward stats vector.
  {
    const auto replica_count = replicas_.size();
    backward_stats_.resize(replica_count);
    const auto variable_count = replicas_[0].size();
    std::for_each(
        backward_stats_.begin(),
        backward_stats_.end(),
        [=](std::vector<int64_t>& v) { v.resize(variable_count); });
  }

  // See Note [Skip allreducing local_used_maps_dev]
  if (find_unused_parameters_) {
    initialize_local_used_map();
  }
}

Next, let's analyze each part in detail.

2.2 Initializing buckets

The initialize_buckets method initializes the buckets. For each bucket it adds its model replica(s), and for each model replica it adds the replica's tensor list:

  • Set rpc_context_ from the distributed autograd context.

    If initialize_buckets is called inside the DDP constructor, it does not matter whether the rpc context pointer is null, because grad will not be mutated.
    If initialize_buckets is called during the training loop, e.g. inside rebuild_buckets, grad may have been mutated and may point to a bucket_view; in that case it must check whether the rpc context pointer is null.
    If the rpc context pointer is null, mutate variable.grad() directly; otherwise, mutate the gradient inside the rpc context.

  • Clear buckets_ and variable_locators_.
  • Resize variable_locators_ so that every variable has a bucket index.
  • Compute the number of buckets and the number of replicas per bucket: bucket_count = bucket_indices.size(); replica_count = replicas_.size().
  • For bucket_index from 0 to bucket_count, initialize each Bucket:

    Create a Bucket bucket.
    If bucket_indices[bucket_index].size() == 1, this bucket expects a single sparse gradient, so set bucket.expect_sparse_gradient = true.
    For replica_index from 0 to replica_count, initialize each BucketReplica:

      Create a BucketReplica replica.
      If the bucket expects a single sparse gradient:

        Take the first element of the index vector via bucket_indices[bucket_index].front() as variable_index.
        Use variable_index to get the corresponding variable from the replica.
        Set the replica's variable list with replica.variables = {variable}; this replica contains only that one variable.

      Otherwise the bucket holds dense gradients:

        Iterate over the bucket's variables, i.e. use replicas_[replica_index][variable_index] to get each variable.
        Record (and check) each variable's device and dtype.
        Add the variable to the replica with replica.variables.push_back(variable).
        Record some per-variable metadata about the flattened contents; for example, offsets stores each tensor's offset within the flat bucket contents.
        Allocate memory for replica.contents.
        Call initialize_bucket_views(replica, replica.contents) to initialize contents and the views.

      Add the replica to the bucket with bucket.replicas.push_back(std::move(replica)).

    Iterate over the variables in the bucket (bucket_indices[bucket_index]) and set Reducer.variable_locators_, so the Reducer knows how to locate a variable inside a bucket: bucket_index is the position in the buckets_ list (i.e. a bucket inside buckets_), and intra_bucket_index is the variable's index within the vectors of the bucket replica.
    Set the bucket's variables: bucket.variable_indices = std::move(bucket_indices[bucket_index]).
    Add the bucket to the Reducer with buckets_.push_back(std::move(bucket)).

The code is:

void Reducer::initialize_buckets(
    std::vector<std::vector<size_t>> bucket_indices) {
  // If initialize_buckets is called inside DDP constructor, then
  // it does not matter rpc context ptr is nullptr or not, as grad
  // will not be mutated.
  // If initialize_buckets is called during training loop, e.g, inside
  // rebuild_buckets(), since grad could be mutated and be pointed to
  // bucket_view, then it needs to check rpc context ptr is nullptr or not,
  // If rpc context ptr is nullptr, mutate variable.grad(); otherwise,
  // mutate grad in rpc context.
#ifndef _WIN32
  using torch::distributed::autograd::ThreadLocalDistAutogradContext;
  this->rpc_context_.set(ThreadLocalDistAutogradContext::getContextPtr());
#endif

  // This shouldn't be called if we're expecting autograd hooks to fire.
  TORCH_CHECK(
      !expect_autograd_hooks_,
      "`initialize_buckets` must NOT be called during autograd execution.");

  // Clear current bucket assignment.
  buckets_.clear();
  variable_locators_.clear();

  // Ensure we have a bucket index for every variable.
  variable_locators_.resize(replicas_[0].size());

  // Iterate over buckets.
  const auto bucket_count = bucket_indices.size();
  const auto replica_count = replicas_.size();
  buckets_.reserve(bucket_count);
  // iterate bucket_index from 0 up to bucket_count
  for (size_t bucket_index = 0; bucket_index < bucket_count; bucket_index++) {
    Bucket bucket; // create a bucket

    // TODO(@pietern): Validate indices.
    // Must be non-empty, unique, and unique across buckets.
    TORCH_CHECK(
        bucket_indices[bucket_index].size() > 0, "Empty bucket specified.");

    // Variables that expect sparse gradients must have their own bucket.
    if (bucket_indices[bucket_index].size() == 1) {
      // this bucket expects a single sparse gradient
      const auto variable_index = bucket_indices[bucket_index].front();
      bucket.expect_sparse_gradient =
          expect_sparse_gradients_[0][variable_index];
    } else {
      for (const auto variable_index : bucket_indices[bucket_index]) {
        TORCH_CHECK(
            !expect_sparse_gradients_[0][variable_index],
            "Buckets with more than one variable cannot include variables ",
            "that expect a sparse gradient.");
      }
    }

    // Iterate over model replicas: from 0 up to replica_count, applying the same setup to every model replica.
    for (size_t replica_index = 0; replica_index < replica_count;
         replica_index++) {
      BucketReplica replica; // create a replica

      if (bucket.expect_sparse_gradient) {
        // this bucket expects a single sparse gradient
        const auto variable_index = bucket_indices[bucket_index].front(); // get the tensor's index
        const auto& variable = replicas_[replica_index][variable_index]; // get the tensor
        TORCH_INTERNAL_ASSERT(bucket_indices[bucket_index].size() == 1);
        replica.variables = {variable}; // this replica contains only one variable
      } else {
        at::TensorOptions options;
        // The start index of the variable in the flattened tensor.
        size_t offset = 0;

        // Reserve enough space for the per-variable fields stored in bucket
        // replica for efficiency.
        const size_t num_variables = bucket_indices[bucket_index].size();
        replica.variables.reserve(num_variables); 
        replica.offsets.reserve(num_variables);
        replica.lengths.reserve(num_variables);
        replica.sizes_vec.reserve(num_variables);

        // Iterate over bucket variables.
        for (const auto variable_index : bucket_indices[bucket_index]) { // iterate over the variables in the bucket
          TORCH_CHECK(
              variable_index < replicas_[replica_index].size(),
              "Out of range variable index specified.");
          const auto& variable = replicas_[replica_index][variable_index];
          if (!options.has_device()) {
            options = options.device(variable.device());
          } else {
            TORCH_CHECK(
                variable.device() == options.device(),
                "All parameters in a bucket must be ",
                "placed on the same device.");
          }
          if (!options.has_dtype()) {
            options = options.dtype(variable.dtype());
          } else {
            TORCH_CHECK(
                variable.dtype() == options.dtype(),
                "All parameters in a bucket must have the same dtype.");
          }
          
          const auto length = variable.numel();
          // add the variable to the replica
          replica.variables.push_back(variable); // a new variable is added here, so in the end the number of variables in the bucket is known
          // record per-variable metadata for the replica
          replica.offsets.push_back(offset);
          replica.lengths.push_back(length);
          replica.sizes_vec.push_back(variable.sizes());
          offset += length;
        }

        // Allocate bucket contents tensor.
        replica.contents = at::empty({static_cast<long>(offset)}, options);

        initialize_bucket_views(replica, replica.contents); // initialize contents and views
      }

      // Add bucket replica to enclosing bucket.
      bucket.replicas.push_back(std::move(replica)); // append a new replica to the bucket's replica list
    }

    // Map participating variables to this bucket.
    // This is identical across replicas so we only need to do this once.
    size_t intra_bucket_index = 0;
    for (const auto variable_index : bucket_indices[bucket_index]) { // iterate over the variables in the bucket
      TORCH_CHECK(
          variable_index < variable_locators_.size(),
          "Out of range variable index specified.");
      variable_locators_[variable_index] = // so the Reducer knows how to locate a variable inside a bucket
          VariableLocator(bucket_index, intra_bucket_index++);
    }
    bucket.variable_indices = std::move(bucket_indices[bucket_index]);

    buckets_.push_back(std::move(bucket)); // add the bucket to the Reducer
  }
}
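As a standalone sketch (independent of the C++ types above), the per-replica bookkeeping for a dense bucket boils down to computing offsets and lengths for each variable and allocating one flat contents tensor:

import torch

# Hypothetical bucket members.
variables = [torch.randn(2, 3), torch.randn(4), torch.randn(5)]

offsets, lengths, offset = [], [], 0
for v in variables:
    offsets.append(offset)        # where this variable starts in the flat tensor
    lengths.append(v.numel())     # how many elements it occupies
    offset += v.numel()

# One flat tensor holding the whole bucket, like replica.contents.
contents = torch.empty(offset, dtype=variables[0].dtype)
print(offsets, lengths, contents.numel())   # [0, 6, 10] [6, 4, 5] 15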

2.3 Initializing views

initialize_bucket_views sets up the replica's contents and views.

// (see Note:  "Gradient Layout Contract" in initialize_buckets).
void Reducer::initialize_bucket_views(
    Reducer::BucketReplica& replica,
    at::Tensor& contents) {
  for (size_t i = 0; i < replica.variables.size(); i++) {
    auto& v = replica.variables[i];
    const auto offset = replica.offsets[i];
    const auto length = replica.lengths[i];
    if (v.is_non_overlapping_and_dense()) { // the parameter's memory layout is dense
      // If the param's memory is dense, match its layout, anticipating
      // the autograd engine (AccumulateGrad) will also create gradients
      // matching its layout.
      replica.bucket_views_in.push_back( // bucket_views_in holds views
          contents.as_strided(v.sizes(), v.strides(), offset));
    } else { // the parameter's layout is not dense; fall back to a contiguous view
      // Fall back to a C-style contiguous view, again anticipating
      // AccumulateGrad will do the same when stashing grads for non-dense
      // params.
      replica.bucket_views_in.push_back( // bucket_views_in holds views
          contents.narrow(0, offset, length).view(v.sizes()));
    }
    // By default `bucket_views_out` and `bucket_views_in` are
    // essentially the same thing.
    replica.bucket_views_out = replica.bucket_views_in; // the out views are the same views by default

    // If gradient_as_bucket_view_ is set as true, then there are two cases to
    // handle: initialize_bucket_views could be called inside initialize_buckets
    // when rebuild_buckets, if grad has already been defined/calculated in
    // previous iteration, old grad needs to be copied into new bucket_view and
    // let grad point to the new bucket_view, initialize_bucket_views could also
    // be called inside initialize_buckets during construction. Grads are not
    // defined during construction time, in this case, do not let grad point to
    // bucket_view, because grads should be kept as being undefined for globally
    // unused parameters.
    if (gradient_as_bucket_view_) {
      auto& bucket_view = replica.bucket_views_in.back();
      runGradCallbackForVariable(v, [&](auto& grad) {
        if (grad.defined() && !grad.is_alias_of(bucket_view)) {
          bucket_view.copy_(grad);
          grad = bucket_view; // the grad was modified and needs to be written back
          // The grad is modefied and needs to be written back.
          return true;
        }
        // The grad is not modified and does not need to be written back.
        return false; // no write-back needed, since the grad was not modified
      });
    }
  }
}

2.3.1 BucketReplica member variables

Let's first recall a few member variables of BucketReplica.

  • at::Tensor contents: the result of flattening the bucket's contents into one dimension.
  • std::vector<at::Tensor> bucket_views_in: views into contents for accessing each gradient from the input side.
  • std::vector<at::Tensor> bucket_views_out: views into contents for accessing each gradient from the output side (i.e. the reduction result).

A few more notes on std::vector<at::Tensor> bucket_views_in and std::vector<at::Tensor> bucket_views_out:

  • These two variables provide the means of manipulating the individual gradients inside contents; in other words, they are views through which each tensor's gradient inside contents can be accessed. They are the entry points used to move each gradient's data into and out of contents.
  • In PyTorch, a view is a convenient way of looking at data: it shares memory with the original data and merely re-organizes it, showing part of it or showing it in a different order, without copying.

A couple of PyTorch functions also deserve a note (see the small sketch after this list):

  • as_strided: creates a view (still a tensor) over an existing tensor using the given sizes and strides; note that the result is a view, so it still shares memory with the original tensor.
  • narrow: returns a tensor that is a narrowed slice of the original tensor; it also shares memory with the original tensor.
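Here is a small standalone sketch (arbitrary shapes) showing that both as_strided and narrow return views that share storage with the flat tensor:

import torch

contents = torch.zeros(6)

# A 2x2 view over the first 4 elements, and a 2x1 view over the last 2 elements.
view_a = contents.as_strided((2, 2), (2, 1), 0)
view_b = contents.narrow(0, 4, 2).view(2, 1)

view_a.fill_(1.0)
view_b.fill_(2.0)
print(contents)   # tensor([1., 1., 1., 1., 2., 2.]) -- the writes go through the views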

The BucketReplica logic is shown in the following figure:

+------------------------------------------+
| BucketReplica                            |
|                                          |
|       vector<Tensor> bucket_views_in +--------------------+
|                                          |                |
|                                          |                |
|       vector<Tensor> bucket_views_out +--------------+    |
|                                          |           |    |
|                                          |           |    |
|                                          |           v    v
|                                          |     +-----+----+--------------------------+
|       Tensor contents  +---------------------> |Flattened (Tensor1, Tensor2, Tensor3)|
|                                          |     +-------------------------------------+
|                                          |
|                                          |
|       vector<Tensor> variables  +------------>  [Tensor1,Tensor2,Tensor3]
|                                          |
|                                          |
|                                          |
+------------------------------------------+

2.3.2 Invocation

How is initialize_bucket_views called? If gradient_as_bucket_view_ is set to true, there are two cases to handle (a usage sketch follows this list):

  • initialize_bucket_views can be called from initialize_buckets during rebuild_buckets; if a grad was already defined/computed in a previous iteration, the old grad needs to be copied into the new bucket_view, and grad is then pointed at the new bucket_view.
  • initialize_bucket_views can also be called from initialize_buckets during construction. Grads are not defined at construction time; in that case, grad must not be pointed at the bucket_view, because for globally unused parameters the grads should stay undefined.
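At the user level, gradient_as_bucket_view is simply a flag on the DDP constructor. A hedged usage sketch (it assumes a process group is already initialized and that `model` and `rank` exist):

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

# Sketch only: requires torch.distributed to be initialized beforehand.
ddp_model = DDP(
    model,                         # assumed to exist
    device_ids=[rank],             # assumed one-GPU-per-process setup
    gradient_as_bucket_view=True,  # after backward, p.grad is a view into the bucket
)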

2.4 Initializing the locally-used map

initialize_local_used_map initializes local_used_maps_. Recall from the paper that local_used_maps_ is used to find globally unused parameters:

The gradients of globally unused parameters should stay untouched during forward and backward. Detecting unused parameters requires global information, because a parameter may be absent in one DDP process's pass yet participate in training in the same iteration of another process. DDP therefore keeps information about locally unused parameters in a bitmap and launches an extra AllReduce to gather a global bitmap. Since the bitmap is far smaller than the tensors, all parameters in the model share one bitmap instead of creating per-bucket bitmaps. The bitmap lives on the CPU to avoid launching a dedicated CUDA kernel for every update. However, some ProcessGroup backends may not be able to run AllReduce on CPU tensors; for example, ProcessGroupNCCL only supports CUDA tensors. Moreover, since DDP should work with any custom ProcessGroup backend, it cannot assume that all backends support CPU tensors. To solve this, DDP maintains another bitmap on the same device as the first model parameter and issues a non-blocking copy to move the CPU bitmap to that device bitmap for the collective communication.

The code is as follows:

void Reducer::initialize_local_used_map() {
  const auto replica_count = replicas_.size();
  const auto variable_count = replicas_[0].size();
  local_used_maps_.resize(replica_count);
  local_used_maps_dev_.resize(replica_count);

  for (size_t i = 0; i < replica_count; i++) {
    at::TensorOptions options;
    options = options.dtype(at::kInt);

    // Deliberately don't pin the memory even if local_used_maps_dev_ will
    // be cuda. See Note [local_used_maps_ -> local_used_maps_dev copying]
    local_used_maps_[i] =
        at::zeros({static_cast<long>(variable_count)}, options);

    // This tensor needs to be on the same device as replica because backend
    // such as NCCL may not support CPU tensors, and hence it might not work
    // if we always put it on CPU.
    options = options.device(replicas_[i][0].device());
    local_used_maps_dev_[i] =
        at::empty({static_cast<long>(variable_count)}, options);
  }
}
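A rough Python sketch of the bitmap idea described above (not the C++ code itself): a small CPU int tensor marks the locally used parameters, is copied non-blocking to the device of the first parameter, and would then be all-reduced to obtain the global map:

import torch

variable_count = 20
device = "cuda:0" if torch.cuda.is_available() else "cpu"

local_used = torch.zeros(variable_count, dtype=torch.int32)    # CPU bitmap
local_used[[0, 3, 7]] = 1                                      # hypothetical locally used params
local_used_dev = local_used.to(device, non_blocking=True)      # device-side copy
# torch.distributed.all_reduce(local_used_dev)  # would gather the global bitmap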

The initialization flow is roughly as follows:

                                    +
                                    |
                                    |
                                    v
                  rpc_context_ = ThreadLocalDistAutogradContext
                                    +
                                    |
                                    |
                                    v
                  buckets_ & variable_locators_ (clear & resize)
                                    +
                                    |
                                    |
                                    v
+----------------------->  from 0 ~ bucket_count :  +--------------------------->
|                                                                                +
|                                                                                |
|      +-------------------------------------------------------------------+     |
|      | init Bucket          set bucket_indices                           |     |
|      |                            +                                      |     |
|      |                            |                                      |     |
|      |                            |                                      |     |
|      |                            v                                      |     |
|      |   ^ +------------> from 0 ~ replica_count : +----------------->   |     |
|      |   |                                                           |   |     |
|      |   |  +---------------------------------------------------+    |   |     |
|      |   |  | init BucketReplica                                |    |   |     |
|      |   |  |                                                   |    |   |     |
<----+ |   +--+                                                   | <--+   | <---+
       |      |    bucket.replicas.push_back(std::move(replica))  |        |
       |      |                                                   |        |
       |      +----------------------+----------------------------+        |
       |                             |                                     |
       |                             |                                     |
       |                             v                                     |
       |             buckets_.push_back(std::move(bucket))                 |
       |                             +                                     |
       +-------------------------------------------------------------------+
                                     |
                                     v

The resulting Reducer looks roughly like the figure below; note that each bucket has only one BucketReplica:

            +----------------------------------------+                 +------------------+
            |tensor index 4, tensor index 5, tensor 6| <------+        | index 2, index 3 |
            +----------------------------------------+        |        +--------------+---+
                                                              |                       ^
                                                              |                       |
+---------------------------+   +---------------------------------------------------------+
| Reducer                   |   | +----------------------------------+     +------------+ |
|                           |   | |Bucket                     |      |     |Bucket    | | |
|                           |   | |                           +      |     |          | | |
| vector<Bucket> buckets_ +---> | | vector<size_t> variable_indices  |     | indices ++ | |
|                           |   | |                                  |     |            | |
|                           |   | |  vector<BucketReplica> replicas  | ... | replicas   | |
|                           |   | |                         +        |     |   +        | |
|                           |   | |                         |        |     |   |        | |
|                           |   | +----------------------------------+     +------------+ |
|                           |   |                           |                  |          |
+---------------------------+   +---------------------------------------------------------+
                                                            |                  |
                                                            |                  |
                                                            v                  v
                          +---------------------------------------+   +-------------------+
                          |  +----------------------------------+ |   | +---------------+ |
                          |  | BucketReplica                    | |   | | BucketReplica | |
                          |  |                                  | |   | |               | |
                          |  |                                  | |   | |               | |
                          |  |  vector<Tensor> bucket_views_in  | |   | |   views_in    | |
                          |  |                                  | |   | |               | |
                          |  |  vector<Tensor> bucket_views_out | |   | |   views_out   | |
                          |  |                                  | |   | |               | |
                          |  |  Tensor contents                 | |   | |   contents    | |
                          |  |                                  | |   | |               | |
                          |  |  vector<Tensor> variables        | |   | |   variables   | |
                          |  |                     +            | |   | |      +        | |
                          |  +----------------------------------+ |   | +---------------+ |
                          +---------------------------------------+   +-------------------+
                                                   |                           |
                                                   |                           |
                                                   v                           v
                                   +---------------+------------+    +---------+----------+
                                   |Tensor 4, Tensor 5, Tensor 6|    | Tensor 2, Tensor 3 |
                                   +----------------------------+    +--------------------+

0x03 Static graph

3.1 Motivation

Although PyTorch builds graphs dynamically, the user can explicitly tell DDP that the training graph is static. This can be set when:

  • The set of used and unused parameters does not change throughout the whole training loop; in that case it does not matter whether the user sets find_unused_parameters to true.
  • The way the graph is trained does not change during the whole training loop (i.e. there is no control flow that depends on the iteration). When the graph is declared static, DDP supports cases that could not be supported before, for example:

    reentrant backward passes;
    activation checkpointing multiple times;
    activation checkpointing with find_unused_parameters = true;
    not all output tensors being used in the loss computation;
    model parameters that live outside the forward function;
    potentially better performance when find_unused_parameters = true or when there are unused parameters, because DDP no longer searches the graph in every iteration to detect unused parameters.

3.2 Usage

_set_static_graph configures the static graph. This API should be called after the DistributedDataParallel constructor and before the training loop starts, and it should be called in the same way on all ranks. For example:

ddp_model = DistributedDataParallel(model)
ddp_model._set_static_graph()
for i in range(n):
    ...

The code of _set_static_graph is:

def _set_static_graph(self):
    """
    Users can explicitly let DDP know the trained graph is static,
    when 1) the set of used and unused parameters will not change
    during the whole training loop; in this case, it does not matter
    whether users set find_unsued_parameters = true or not.
    2) how the graph is trained will not change during the whole training
    loop (meaning there is no control flow depending on iterations).
    When graph is set to be static, DDP will support cases that can not
    be supported in the past: 1) reentrant backwards
    2) activation checkpointing multiple times 3)
    activation checkpointing with find_unused_parameters = true.
    4) not all output tensors are used in loss calculation.
    5) there is model parameter that is outside of forward function.
    6) potentially improve performance when find_unsued_parameters = true
    or there are unused parameters, as DDP will not search graph in each
    iteraton to detect unused parameters when static_graph is set to be True.

    This API should be called after DistributedDataParallel construction, and
    before training loops starts. Also it should be called in the same way for
    all ranks. For example:
        ddp_model = DistributedDataParallel(model)
        ddp_model._set_static_graph()
        for i in range(n):
            .....
    """
    self.static_graph = True
    self.reducer._set_static_graph() # configure the Reducer as well
    self.logger._set_static_graph()
    if self.find_unused_parameters:
        warnings.warn(
            "You passed find_unused_parameters=true to DistributedDataParallel, "
            "`_set_static_graph` will detect unused parameters automatically, so "
            "you do not need to set find_unused_parameters=true, just be sure these "
            "unused parameters will not change during training loop while calling "
            "`_set_static_graph`."
        )

3.3 Reducer

The Reducer can only establish the static graph after the first iteration; PyTorch is, after all, dynamic, so one iteration always has to run dynamically first.

void Reducer::set_static_graph() {
  std::lock_guard<std::mutex> lock(mutex_);
  TORCH_CHECK(
      num_iterations_ == 0,
      "set_static_graph() should be called before training loop starts "
      "and after DistributedDataParallel is constructed.");
  static_graph_ = true;
  // when static_graph_ is set as true, always initialize_local_used_map
  // and detect the global unused parameters in the first iteration.
  initialize_local_used_map();
}

0x04 Rebuilding buckets

4.1 Why rebuild

Because PyTorch builds the computation graph dynamically, the buckets need to be rebuilt accordingly. Buckets are rebuilt only once, after the first iteration, and only when the graph is static or find_unused_parameters_ is false; if find_unused_parameters_ is set (and the graph is not static), they are never rebuilt.

  // Returns true if we should rebuild buckets, else false. We only rebuild
  // buckets once after the first iteration and never rebuild them if
  // find_unused_parameters_.
  inline bool should_rebuild_buckets() const {
    return (static_graph_ || !find_unused_parameters_) && !has_rebuilt_bucket_;
  }

4.2 Preparing to rebuild

Let's first look at some preparation done before rebuilding.

push_rebuilt_params simply appends one parameter to the rebuild lists.

void Reducer::push_rebuilt_params(const VariableIndex& index) {
  rebuilt_params_.push_back(
      replicas_[index.replica_index][index.variable_index]);
  rebuilt_param_indices_.push_back(index.variable_index);
}

Next, push_rebuilt_params_for_all_indices iterates over every replica and pushes every variable of that replica.

void Reducer::push_rebuilt_params_for_all_indices() {
  std::lock_guard<std::mutex> lock(mutex_);
  if (!should_rebuild_buckets() || !rebuilt_param_indices_.empty()) {
    return;
  }
  const auto replica_count = replicas_.size();
  for (size_t replica_index = 0; replica_index < replica_count;
       ++replica_index) {
    const auto variable_count = replicas_[replica_index].size();
    for (size_t variable_index = 0; variable_index < variable_count;
         ++variable_index) {
      const auto index = VariableIndex(replica_index, variable_index);
      push_rebuilt_params(index);
    }
  }
}

4.3 Rebuilding

Now let's look at the rebuild mechanism itself.

DDP uses rebuilt_params_ and rebuilt_param_indices_ to rebuild the buckets according to the order in which tensors receive their gradients during the backward pass.

rebuild_buckets performs broadcast communication and can overlap with the next forward() call, so it can be asynchronous.

  • With find_unused_parameters=true, rebuilding buckets would be an asynchronous operation, because buckets could be rebuilt multiple times: subgraphs are trained, and the parameter index order may change more frequently.
  • With find_unused_parameters=false, buckets are rebuilt only once and the performance cost is negligible. rebuild_buckets returns true if the buckets have been rebuilt.
bool Reducer::rebuild_buckets() {
  // Ensure reduction for previous backwards pass is finished. If user's model
  // has unused parameters for example, this will raise an error recommending to
  // run with find_unused_parameters=True, instead of the size mismatch
  // exception below.
  std::lock_guard<std::mutex> lock(mutex_);
  ensure_prior_reduction_finished();
  if (!should_rebuild_buckets() || rebuilt_params_.empty()) {
    return false;
  }

  std::vector<std::vector<size_t>> rebuilt_bucket_indices;
  std::vector<size_t> bucket_size_limits;
  bucket_size_limits.push_back(kDefaultFirstBucketBytes);
  bucket_size_limits.push_back(bucket_bytes_cap_);
  rebuilt_bucket_indices = compute_bucket_assignment_by_size(
      rebuilt_params_,
      bucket_size_limits,
      expect_sparse_gradients_[0],
      rebuilt_param_indices_);

  // For rebuilt bucket indices, it needs to be synced across all ranks.
  // Broadcast the newly rebuilt bucket indices from rank 0 in default.
  // After syncing up rebuilt bucket indices, initialize buckets for reducer.
  sync_bucket_indices(rebuilt_bucket_indices);

  has_rebuilt_bucket_ = true; // rebuild only once
  rebuilt_params_.clear();
  rebuilt_param_indices_.clear();

  initialize_buckets(std::move(rebuilt_bucket_indices));
  return true;
}

4.4 When rebuilding is set

Rebuilding is set up only when:

  • the buckets are being rebuilt for the first time;
  • static_graph_ is true or find_unused_parameters_ is false;
  • this backward pass needs to run allreduce.

Here we simply dump the tensors and their parameter indices into rebuilt_params_ and rebuilt_param_indices_ in the order their gradients arrive; then, at the end of finalize_backward(), the buckets are rebuilt based on rebuilt_params_ and rebuilt_param_indices_, and are then broadcast and initialized.

Also, we only need to dump the tensors and parameter indices of one replica.

Take mark_variable_ready as an example: it calls push_rebuilt_params(index) to append to the lists.

void Reducer::mark_variable_ready(VariableIndex index) {
  // Rebuild bucket only if 1) it is the first time to rebuild bucket 2)
  // static_graph_ is true or find_unused_parameters_ is false,
  // 3) this backward pass needs to run allreduce.
  // Here, we just dump tensors and their parameter indices into
  // rebuilt_params_ and rebuilt_param_indices_ based on gradient arriving
  // order, and then at the end of finalize_backward(), buckets will be
  // rebuilt based on rebuilt_params_ and rebuilt_param_indices_, and then
  // will be broadcasted and initialized. Also we only need to dump tensors
  // and parameter indices of one replica.
  if (should_rebuild_buckets()) {
    push_rebuilt_params(index); // append to the rebuild lists
  }

  const auto replica_index = index.replica_index;
  const auto variable_index = index.variable_index;

  if (replica_index == 0) {
    checkAndRaiseMarkedTwiceError(variable_index);
    perIterationReadyParams_.insert(variable_index);
  }
  backward_stats_[replica_index][variable_index] =
      current_time_in_nanos() - cpu_timer_.backward_compute_start_time;

  // Any time we mark a variable ready (be it in line due to unused parameters,
  // or via an autograd hook), we require a call to the finalize function. If
  // this doesn't happen before the next iteration (or call to
  // `prepare_for_backwards`), we know something is wrong.
  require_finalize_ = true;

  const auto& bucket_index = variable_locators_[variable_index];
  auto& bucket = buckets_[bucket_index.bucket_index];
  auto& replica = bucket.replicas[replica_index];

  set_divide_factor();

  if (bucket.expect_sparse_gradient) {
    mark_variable_ready_sparse(index);
  } else {
    mark_variable_ready_dense(index);
  }

  // TODO(@pietern): Make this work for both CPU/CUDA tensors.
  // When using CPU tensors we don't need to do this.
  // // Record event so that we can wait for all of them.
  // auto& event = replica.events[bucket_index.intra_bucket_index];
  // event.record();

  // Check if this was the final gradient for this bucket.
  if (--replica.pending == 0) {
    // Kick off reduction if all replicas for this bucket are ready.
    if (--bucket.pending == 0) {
      mark_bucket_ready(bucket_index.bucket_index);
    }
  }

  // Run finalizer function and kick off reduction for local_used_maps once the
  // final bucket was marked ready.
  if (next_bucket_ == buckets_.size()) {

    if (dynamic_graph_find_unused()) {
      all_reduce_local_used_map();
    }

    // The autograd engine uses the default stream when running callbacks, so we
    // pass in the current CUDA stream in case it is not the default.
    const c10::Stream currentStream = get_current_stream();
    torch::autograd::Engine::get_default_engine().queue_callback([=] {
      std::lock_guard<std::mutex> lock(this->mutex_);
      // Run callback with the current stream
      c10::OptionalStreamGuard currentStreamGuard{currentStream};
      if (should_collect_runtime_stats()) {
        record_backward_compute_end_time();
      }
      // Check that all buckets were completed and had their work kicked off.
      TORCH_INTERNAL_ASSERT(next_bucket_ == buckets_.size());
      this->finalize_backward();
    });
  }
}

4.5 Direct invocation

The _rebuild_buckets function can also be called directly. For example, it is called once from forward over the whole training period, as shown below.

def forward(self, *inputs, **kwargs):
    with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
        self.reducer.save_thread_local_state()
        if torch.is_grad_enabled() and self.require_backward_grad_sync:
            self.num_iterations += 1
            self.reducer.prepare_for_forward()
        if self.ddp_uneven_inputs_config.ddp_join_enabled:
            ones = torch.ones(1, device=self.device)
            work = dist.all_reduce(ones, group=self.process_group, async_op=True)
            if self.ddp_uneven_inputs_config.ddp_join_throw_on_early_termination:
                # Active ranks schedule an allreduce with zeros, inactive
                # ranks schedule them with 1. If the result != 0 it
                # indicates at least one rank has terminated and we should
                # throw.
                zeros = torch.zeros(1, device=self.device)
                dist.all_reduce(zeros, group=self.process_group)
                should_throw_stop_iteration = zeros.item()
                if should_throw_stop_iteration:
                    raise RuntimeError(
                        "Detected at least one rank that exhausted inputs. Throwing across all ranks."
                    )
            else:
                self.reducer._set_forward_pass_work_handle(
                    work,
                    self.ddp_uneven_inputs_config.ddp_join_divide_by_initial_world_size,
                )

        # Calling _rebuild_buckets before forward compuation,
        # It may allocate new buckets before deallocating old buckets
        # inside _rebuild_buckets. To save peak memory usage,
        # call _rebuild_buckets before the peak memory usage increases
        # during forward computation.
        # This should be called only once during whole training period.
        
        # the direct call happens here
        if torch.is_grad_enabled() and self.reducer._rebuild_buckets(): # rebuild if needed
            logging.info("Reducer buckets have been rebuilt in this iteration.")

The join method can likewise call the rebuild directly.

@contextmanager
def join(
    self,
    divide_by_initial_world_size=True,
    enable=True,
    throw_on_early_termination=False,
):
  
    # ... other code omitted ...
    
                    else:
                        # Some DDP process still needs to be joined.
                        if self.ddp_uneven_inputs_config.ddp_join_throw_on_early_termination:
                            # Schedule allreduce telling active ranks to terminate
                            ones = torch.ones(1, device=self.device)
                            dist.all_reduce(ones, group=self.process_group)
                            # Raising StopIteration doesn't throw error in python 3.6
                            # and throws RuntimeError in 3.7+ (PEP 479), so just
                            # raise RuntimeError here.
                            raise RuntimeError(
                                f"Rank {self._distributed_rank} exhausted all inputs."
                            )
                        if is_last_joiner:
                            is_last_joiner = False
                        # It will rebuild buckets only once during training period
                        
                        # the call happens here
                        self.reducer._rebuild_buckets()
                        # Schedule a corresponding broadcast if we are syncing module
                        # buffers in the forward pass.
                        self._check_and_sync_module_buffers()   

Since Join has come up, let's look at that concept next.

0x05 Join

Join addresses the problem of uneven training data: it allows workers with fewer inputs (which have already joined) to keep participating in collective communication with the workers that have not finished yet, by shadowing (faking) the collective operations.

5.1 Motivation

DDP is backed by the all-reduce operations of several collective communication libraries, which synchronize gradients across workers. When the training data is distributed unevenly across ranks, DDP hangs: collective communication requires every rank in the process group to participate, so if one rank has fewer inputs, the other ranks hang or error out (depending on the backend), and any class performing synchronous collective communication runs into this in every iteration.

DDP therefore provides the Join API, a context manager to be used around each rank's training loop. A rank with less data exhausts its inputs early; from then on it gives the collective communication an illusion of participation by building dummy all-reduces to match the other ranks. Exactly how that illusion is created is specified by the registered hooks.

The rough idea is as follows:

                +----------------------------+
                |             Data           |
                |   +--------+   +--------+  |
                |   |        |   | Empty  |  |
                |   |        |   |        |  |
                |   +-----+--+   +--------+  |
                |         |                  |
                |         |                  |
                +----------------------------+
                          |
                          |
        +------------+    |               +------------+
        |            |    |               |            |
+---->  |    Model   |    |               |   Model    | <-----+
|       |            |    |               |            |       |
|       +------+-----+    |               +------+-----+       |
|              |          |                      |             |
|              |          |                      |             |
|              v          |                      v             |
|       +------+-----+    |             +--------+----------+  |
|       |  Forward   +<---+             | _JoinHook         |  |
|       |  (local)   |                  |                   |  |
|       +------+-----+                  |                   |  |
|              |                        |                   |  |
|              |                        |                   |  |
|              v                        | +---------------+ |  |
|       +------+-----+                  | | main_hook     | |  |
|       |  Backward  |                  | |               | |  |
|       |  (local)   |                  | |               | |  |
|       +------+-----+                  | |               | |  |
|              |                        | |               | |  |
|              |                        | |               | |  |
|              v                        | |               | |  |
|       +------+-----+                  | |               | |  |
|       | All-Reduce |     Sync grads   | |   All-Reduce  | |  |
|       |            | <--------------> | |   (Dummy)     | |  |
|       +------+-----+                  | |               | |  |
|              |                        | +---------------+ |  |
|              |                        +-------------------+  |
|              v                                 |             |
|     +--------+-------+                         |             |
|     | Update Weights |                         |             |
|     |                |                         |             |
|     +--------+-------+                         |             |
|              |                                 |             |
|              |                                 |             |
+--------------+                                 +-------------+
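A conceptual sketch of the "shadow" step on an already-joined rank (this is not the real DDP hook; the shapes and names are made up): the joined rank keeps issuing matching all-reduces filled with zeros so that the collectives on the active ranks can complete, and contributing zeros does not change the reduced sums:

import torch
import torch.distributed as dist

def shadow_allreduce_step(process_group, grad_shapes, device):
    # One fake iteration: match every all-reduce that the active ranks will issue.
    for shape in grad_shapes:
        dummy = torch.zeros(shape, device=device)
        dist.all_reduce(dummy, group=process_group)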

5.2 Usage

5.2.1 DistributedDataParallel

Join can be used together with DistributedDataParallel. In the example below, two workers are launched, rank 0 and rank 1; rank 0 gets 5 inputs and rank 1 gets 6, which makes the inputs uneven.

Without Join, rank 1 would hang while processing its 6th input, because rank 0 has no corresponding input and rank 1 can only wait. With Join, this does not happen and both ranks finish cleanly.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel as DDP

BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5

def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
    # Rank 1 gets one more input than rank 0
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    num_inputs = 0
    with Join([model]):
        for input in inputs:
            num_inputs += 1
            loss = model(input).sum()
            loss.backward()

    print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")

def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)

if __name__ == "__main__":
    main()

This produces the following output (the lines from rank 0 and rank 1 may appear in either order):

Rank 0 has exhausted all 5 of its inputs!
Rank 1 has exhausted all 6 of its inputs!

5.2.2 ZeroRedundancyOptimizer

The Join context works not only with a single class but with several at once, for example together with PyTorch's ZeroRedundancyOptimizer:
from torch.distributed.optim import ZeroRedundancyOptimizer as ZeRO
from torch.optim import Adam

def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
    optim = ZeRO(model.parameters(), Adam, lr=0.01)
    # Rank 1 gets one more input than rank 0
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    num_inputs = 0
    # Pass both `model` and `optim` into `Join()`
    with Join([model, optim]):
        for input in inputs:
            num_inputs += 1
            loss = model(input).sum()
            loss.backward()
            optim.step()

    print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")

This produces the same output as before. The notable change is that the ZeroRedundancyOptimizer instance additionally has to be passed into Join().

The ZeroRedundancyOptimizer mechanism will be analyzed in a later post.

5.3 Principle

In the latest documentation, https://pytorch.org/tutorials/advanced/generic_join.html, PyTorch gives an explanation, which we translate below.

To use the context manager well, we introduce the Join class together with the supporting classes Joinable and JoinHook.

Note: this part is based on the v1.10.0 code.

5.3.1 Joinable

Joinable

首先,与上下文管理器兼容的类必须继承抽象基类。特别的,必须实现:

Join
Joinable
Joinable
  • join_hook(self, **kwargs) -> JoinHook

This returns a JoinHook instance for the Joinable, which determines how joined processes should shadow the per-iteration collective communications performed by the Joinable.
  • join_device(self) -> torch.device

This returns the device to be used by the Join context manager to perform collective communications, e.g. torch.device("cuda:0") or torch.device("cpu").
  • join_process_group(self) -> ProcessGroup

This returns the process group to be used by the Join context manager to perform collective communications.

In short, the JoinHook is responsible for the concrete behavior, while join_device and join_process_group are responsible for the collective communications.

Note that join_device and join_process_group are required attributes: they ensure that the context manager can schedule collective communications between the "joined" and "not yet joined" processes. One usage is to count the number of non-joined processes in each iteration with an all-reduce. Another usage is to implement the mechanism required by throw_on_early_termination=True, which we explain below.

DistributedDataParallel and ZeroRedundancyOptimizer already inherit from Joinable and implement the methods above, which is why we could use them directly in the earlier examples.
class DistributedDataParallel(Module, Joinable):

class ZeroRedundancyOptimizer(Optimizer, Joinable):

DDP is the one feeding the data, so it makes sense for it to inherit from Joinable; but why does ZeroRedundancyOptimizer also need to inherit? Because ZeroRedundancyOptimizer can work together with DDP and also performs collective operations internally, so it must be managed by Join as well.

A Joinable class should make sure to call the Joinable constructor, because it initializes a JoinConfig instance that the context manager uses internally to ensure correctness. The JoinConfig is saved in the _join_config field of each Joinable.
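
To make the requirements above concrete, here is a minimal sketch of a custom Joinable (my own illustration, not from the article, assuming the v1.10 torch.distributed.algorithms.join API; the class name MyJoinable is made up):

from torch.distributed.algorithms.join import Joinable, JoinHook

class MyJoinable(Joinable):
    def __init__(self, device, process_group):
        super().__init__()   # initializes the _join_config field used by Join
        self._device = device
        self._pg = process_group

    def join_hook(self, **kwargs) -> JoinHook:
        # A real implementation returns a JoinHook whose main_hook() shadows
        # this class's per-iteration collectives (see the Counter example in 5.4).
        return JoinHook()

    @property
    def join_device(self):
        return self._device

    @property
    def join_process_group(self):
        return self._pg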

5.3.2 JoinHook

Next, let us break down the JoinHook class. A JoinHook provides two entry points into the context manager:
  • main_hook(self) -> None

While there still exist ranks that have not yet joined, each joined rank calls this hook repeatedly, once per training iteration (i.e. per forward pass, backward pass, and optimizer step). Its purpose is to shadow the collective communications performed by the Joinable in that iteration; in other words, it defines how an already-joined rank should participate in the collectives together with the not-yet-joined ranks.
  • post_hook(self, is_last_joiner: bool) -> None

This hook is called once all ranks have joined. It receives an extra bool argument, is_last_joiner, indicating whether this rank was one of the last to join. This argument may be useful for synchronization.

5.3.2.1 ZeroRedundancyOptimizer

We use the main hook built into ZeroRedundancyOptimizer as a concrete example: since a joined rank is still responsible for updating and synchronizing its shard of the parameters, the main hook still performs the optimizer step.
class _ZeROJoinHook(_JoinHook):
    def __init__(self, zero):
        assert isinstance(zero, ZeroRedundancyOptimizer), \
            "ZeRO join hook requires passing in a ZeroRedundancyOptimizer " \
            "instance as the state"
        self.zero = zero
        super().__init__()

    def main_hook(self):
        """
        Performs an optimizer step, which updates the joined process's shard of
        the parameters and broadcasts those parameters.
        """
        self.zero.step()

An abridged version of the step function is as follows:

def step(
    self,
    closure: Optional[Callable[[], float]] = None,
    **kwargs: Any,
) -> Optional[float]:
    _Join.notify_join_context(self) # notify the Join context manager here
    # Check if the model trainability has changed
    is_trainable_mask = self._get_is_trainable_mask()
    if is_trainable_mask != self._is_trainable_mask:
        self._build_param_buckets()
        self._is_trainable_mask = is_trainable_mask

    # Sync the exposed `param_groups` attributes to the local optimizer in
    # case they have been updated
    self._sync_param_groups(self.param_groups, self.optim.param_groups)

    # Run the optimizer step on this shard only
    if closure is not None:
        loss = self.optim.step(closure=closure, **kwargs)  # type: ignore[call-arg]
    else:
        loss = self.optim.step(**kwargs)

    # Sync all of the updated parameter shards across the ranks
    self._sync_parameters()

    # Sync any updated attributes in the local optimizer to the exposed
    # `param_groups`
    self._sync_param_groups(self.optim.param_groups, self.param_groups)

    return loss

Now let us look at the hooks of DistributedDataParallel:

  • main_hook still performs a series of operations to shadow the collective communications of the ranks that have not yet joined.
  • post_hook broadcasts the final updated model from one of the last joining ranks, to ensure that the model is the same on all ranks.
class _DDPJoinHook(_JoinHook):
    def __init__(self, ddp, divide_by_initial_world_size):
        """
        Sets config variables for internal usage.
        """
        ddp.logger._set_uneven_input_join()
        self.ddp = ddp
        self.ddp._divide_by_initial_world_size = divide_by_initial_world_size
        super().__init__()

    def main_hook(self):
        """
        Shadows the DDP collective communication operations in the forward and
        backward passes.
        """
        ddp = self.ddp
        # Buckets are rebuilt only once during a training period
        ddp.reducer._rebuild_buckets()

        # Schedule a broadcast if we are syncing module buffers in the
        # forward pass
        ddp._check_and_sync_module_buffers()

        # Check if need to sync in the backward pass
        work = ddp._check_global_requires_backward_grad_sync(is_joined_rank=True)
        work.wait()
        should_sync_backwards = work.result()[0].item() != 0
        # Forward parameter sync is disabled in the next iteration if we
        # are skipping gradient sync this iteration, so set
        # `require_forward_param_sync` accordingly
        ddp.require_forward_param_sync = should_sync_backwards
        if not should_sync_backwards:
            return

        # Schedule one allreduce per gradient bucket to match the backward
        # pass allreduce
        ddp._match_all_reduce_for_bwd_pass()

        # Check if we need to allreduce locally unused parameters
        if ddp.find_unused_parameters:
            ddp._match_unused_params_allreduce()

        # Rebuilt parameters are pushed only once during a training period
        ddp.reducer._push_all_rebuilt_params()

    def post_hook(self, is_last_joiner: bool):
        """
        Syncs the final model to ensure that the model is the same across all
        processes.
        """
        self.ddp._sync_final_model(is_last_joiner)

_sync_final_model is where the latest model gets broadcast.

# When running in join mode, agrees upon a common rank and broadcast model
# parameters to all other ranks.
def _sync_final_model(self, is_last_joiner):
    # Agree upon the process that will be the authoritative model copy.
    # The current rank is a candidate for being the authoritative copy if
    # is_last_joiner=True. We break ties via picking the larger rank.
    self._authoritative_rank = self._find_common_rank(
        self._distributed_rank, is_last_joiner
    )
    self._sync_params_and_buffers(authoritative_rank=self._authoritative_rank)

5.3.3 Join

Finally, let us look at how these base classes fit into the Join class itself.
  • __init__(self, joinables: List[Joinable], enable: bool = True, throw_on_early_termination: bool = False)

As we saw in the earlier examples, the constructor takes in the list of Joinables that participate in the training loop; these should be the classes that perform a collective communication in every iteration.

enable is a bool that can be set to False if you know there will be no uneven inputs, in which case the context manager becomes vacuous, similar to contextlib.nullcontext(). This may also disable join-related computation in the participating Joinables.

throw_on_early_termination is a bool that can be set to True so that every rank raises an exception the moment uneven inputs are detected. This is useful for cases that do not satisfy the context manager's requirements, typically when collective communications from different classes can be arbitrarily interleaved, for example when using DistributedDataParallel with a model that contains SyncBatchNorm layers. In that case, set this argument to True so that the application logic can catch the exception and decide how to proceed.
  • The core logic lives in the __exit__() method: while there still exist non-joined ranks, it loops and calls each Joinable's main hook; once all ranks have joined, it calls their post hooks. Both the main hooks and the post hooks are iterated in the order in which the Joinables were passed in.
  • The context manager needs a heartbeat from the non-joined processes, so every Joinable class should call Join.notify_join_context() before its per-iteration collective communications. The context manager makes sure that only the first Joinable passed in actually sends the heartbeat.
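
As a small usage sketch of throw_on_early_termination (my own illustration, not from the article; the function name train_one_epoch is made up, and it assumes, as in the PyTorch tutorial, that the early termination surfaces as a RuntimeError on every rank):

from torch.distributed.algorithms.join import Join

def train_one_epoch(model, optim, inputs):
    # `model` is a DDP instance, `optim` a ZeroRedundancyOptimizer;
    # `inputs` may have a different length on each rank.
    try:
        with Join([model, optim], throw_on_early_termination=True):
            for inp in inputs:
                loss = model(inp).sum()
                loss.backward()
                optim.step()
    except RuntimeError:
        # Some rank exhausted its inputs; every rank raises, so all ranks land
        # here and can agree on how to continue (here we simply end the epoch).
        pass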

5.4 Example

Let us make this concrete with an example. In the code below, each rank prints (1) the number of inputs, summed over all ranks, that were processed before it joined, and (2) the total number of inputs across all ranks.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join, Joinable, JoinHook

BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5

class CounterJoinHook(JoinHook):
    r"""
    Join hook for :class:`Counter`.

    Arguments:
        counter (Counter): the :class:`Counter` object using this hook.
        sync_max_count (bool): whether to sync the max count once all ranks
            join.
    """
    def __init__(
        self,
        counter,
        sync_max_count
    ):
        self.counter = counter
        self.sync_max_count = sync_max_count

    def main_hook(self):
        r"""
        Shadows the counter's all-reduce by all-reducing a dim-1 zero tensor.
        """
        t = torch.zeros(1, device=self.counter.device)
        dist.all_reduce(t)

    def post_hook(self, is_last_joiner: bool):
        r"""
        Synchronizes the max count across all :class:`Counter` s if
        ``sync_max_count=True``.
        """
        if not self.sync_max_count:
            return
        rank = dist.get_rank(self.counter.process_group)
        common_rank = self.counter.find_common_rank(rank, is_last_joiner)
        if rank == common_rank:
            self.counter.max_count = self.counter.count.detach().clone()
        dist.broadcast(self.counter.max_count, src=common_rank)

class Counter(Joinable):
    r"""
    Example :class:`Joinable` that counts the number of training iterations
    that it participates in.
    """
    def __init__(self, device, process_group):
        super(Counter, self).__init__()
        self.device = device
        self.process_group = process_group
        self.count = torch.tensor([0], device=device).float()
        self.max_count = torch.tensor([0], device=device).float()

    def __call__(self):
        r"""
        Counts the number of inputs processed on this iteration by all ranks
        by all-reducing a dim-1 one tensor; increments its own internal count.
        """
        Join.notify_join_context(self)
        t = torch.ones(1, device=self.device).float()
        dist.all_reduce(t)
        self.count += t

    def join_hook(self, **kwargs) -> JoinHook:
        r"""
        Return a join hook that shadows the all-reduce in :meth:`__call__`.

        This join hook supports the following keyword arguments:
            sync_max_count (bool, optional): whether to synchronize the maximum
                count across all ranks once all ranks join; default is ``False``.
        """
        sync_max_count = kwargs.get("sync_max_count", False)
        return CounterJoinHook(self, sync_max_count)

    @property
    def join_device(self) -> torch.device:
        return self.device

    @property
    def join_process_group(self):
        return self.process_group

    # Determine the last joined rank; since more than one rank may join last, pick the largest such rank to synchronize from
    def find_common_rank(self, rank, to_consider):
        r"""
        Returns the max rank of the ones to consider over the process group.
        """
        common_rank = torch.tensor([rank if to_consider else -1], device=self.device)
        dist.all_reduce(common_rank, op=dist.ReduceOp.MAX, group=self.process_group)
        common_rank = common_rank.item()
        return common_rank

def worker(rank):
    assert torch.cuda.device_count() >= WORLD_SIZE
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    counter = Counter(torch.device(f"cuda:{rank}"), dist.group.WORLD)
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    with Join([counter], sync_max_count=True):
        for _ in inputs:
            counter()

    print(f"{int(counter.count.item())} inputs processed before rank {rank} joined!")
    print(f"{int(counter.max_count.item())} inputs processed across all ranks!")

def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)

if __name__ == "__main__":
    main()

Since rank 0 sees 5 inputs and rank 1 sees 6, this produces the output:

10 inputs processed before rank 0 joined!
11 inputs processed across all ranks!
11 inputs processed before rank 1 joined!
11 inputs processed across all ranks!
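
A quick sanity check of these numbers: during the first 5 iterations both ranks call Counter(), so each all-reduce sums to 2 and rank 0 joins with count = 5 * 2 = 10. In rank 1's sixth iteration, rank 1 all-reduces 1 while rank 0's main hook all-reduces a zero tensor, so the sum is 1 and rank 1 joins with count = 10 + 1 = 11. max_count is then broadcast from the last joiner (rank 1), which is why both ranks report 11 inputs across all ranks.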

Some key points to emphasize:

  • A Counter instance performs one all-reduce per iteration, therefore:

    For a rank that has already joined, its main hook also performs a single all-reduce to shadow it; this all-reduce uses a zero tensor, so it does not change the overall result.
    The other, not-yet-joined ranks still see what looks like a correct, fully-attended collective operation.
    This is how uneven inputs are handled.

  • The Counter class calls Join.notify_join_context() at the beginning of its __call__() method, because that is where the per-iteration collective (the all-reduce) happens and the context manager must be notified there; ranks that have already joined no longer reach this call.
  • The is_last_joiner argument is used to determine the broadcast source in the post hooks.
  • We pass the sync_max_count keyword argument to the context manager, which forwards it to Counter's join hook.
  • In the post hook, self.counter.max_count is broadcast.

0xFF References

pytorch分布式系列3——分布式训练时,torch.utils.data.distributed.DistributedSampler做了什么?

pytorch分布式系列1——搞清torch.distributed.launch相关的环境变量

pytorch分布式系列2——DistributedDataParallel是如何做同步的?

pytorch(分布式)数据并行个人实践总结——DataParallel/DistributedDataParallel

Pytorch的nn.DataParallel

https://discuss.pytorch.org/t/dataparallel-imbalanced-memory-usage/22551/20

https://pytorch.org/docs/stable/distributed.html

PyTorch 源码解读之分布式训练了解一下?

实操教程|PyTorch AutoGrad C++层实现

PYTORCH 自动微分(一)

PyTorch如何加速数据并行训练?分布式秘籍大揭秘

pytorch分布式训练(二init_process_group)

https://pytorch.org/tutorials/intermediate/ddp_tutorial.html

https://pytorch.org/docs/master/notes/ddp.html

https://pytorch.org/tutorials/intermediate/dist_tuto.html

PyTorch 源码解读之 DP & DDP:模型并行和分布式训练解析

Pytorch模型中的parameter与buffer

【PyTorch开发者日 2020】PyTorch分布式数据并行(DDP)

[中文字幕] 深入理解 PyTorch 中的 Hook 机制

[中文字幕] 深入解读 Pytorch AutoGrad

DISTRIBUTED TRAINING WITH UNEVEN INPUTS USING THE JOIN CONTEXT MANAGER

谈谈torch1.10中的ZeroRedundancyOptimizer和Join


0x00 summary

Because the previous article has done correlation analysis around various member variables related to reducer, this paper begins to do dynamic logic analysis. The purpose is to connect the previous articles and set the basis for the subsequent analysis of forward propagation and back propagation.

Other articles in this series are as follows:

Automatic differentiation of deep learning tools (1)

Automatic differentiation of deep learning tools (2)

[source code analysis] automatic differentiation of deep learning tools (3) — example interpretation

[source code analysis] how pytorch implements forward propagation (1) — basic class (I)

[source code analysis] how pytorch implements forward propagation (2) — basic class (2)

[source code analysis] how pytorch implements forward propagation (3) — specific implementation

[source code analysis] how pytoch implements backward propagation (1) — call engine

[source code analysis] how pytoch implements backward propagation (2) — engine static structure

[source code analysis] how pytoch implements backward propagation (3) — engine dynamic logic

[source code analysis] how pytorch implements backward propagation (4) — specific algorithm

[source code analysis] pytorch distributed (1) — history and overview

[source code analysis] pytorch distributed (2) — dataparallel (Part 1)

[source code analysis] pytorch distributed (3) — dataparallel (Part 2)

[source code analysis] pytorch distributed (4) — basic concept of distributed application

[source code analysis] pytorch distributed (5) — distributeddataparallel Overview & amp; How to use

[source code analysis] pytorch distributed (6) — distributeddataparallel — initialization & amp; store

[source code analysis] pytorch distributed (7) — process group of distributeddataparallel

[source code analysis] pytorch distributed (8) — a paper on distributed dataparallel

[source code analysis] pytorch distributed (9) — initialization of distributeddataparallel

[source code analysis] pytorch distributed (10) — the reducer static architecture of distributeddataparallel

0x01 introduction

For better analysis, we still need to see how to call.

1.1 call

The creation code of reducer is as follows, which is in_ ddp_ init_ Helper.

        # Note: reverse list of buckets because we want to approximate the
        # order in which their gradients are produced, and assume they
        # are used in the forward pass in the order they are defined.
        self.reducer = dist.Reducer(
            parameters, # parameters[0]是张量列表
            list(reversed(bucket_indices)), # 桶信息
            self.process_group,
            expect_sparse_gradient,
            self.bucket_bytes_cap,
            self.find_unused_parameters,
            self.gradient_as_bucket_view,
            param_to_name_mapping,
        )

1.2 parameter description

The parameters called are as follows. Parameters [0] is the parameters of the model on rank 0. You can see that only [0] elements are meaningful. The original [0] itself includes 20 elements:

parameters = {list: 1} 
0 = {list: 4}           
 0 = {Parameter: 10} Parameter containing:\ntensor([[-4.0381e-02,  3.8828e-02, 1  )   
 1 = {Parameter: 10} Parameter containing:\ntensor([-0.0438, -0.2033,  0.2771,  0.0721,  ) 
 2 = {Parameter: 5} Parameter containing:\ntensor([[-0.0094, -0.1319,  0.0713,  0.3155,  )
 3 = {Parameter: 5} Parameter containing:\ntensor([-0.0008,  0.0582, -0.1245, -0.2538, )
 ...
 20 = {Parameter: 5} Parameter containing:\ntensor([-0.0008,  0.0582, -0.1245, -0.2538, )                                                   
 __len__ = {int} 20
__len__ = {int} 1

bucket_ Examples of indices are as follows:

For tensor indexes, we give all tensors an index, which increases from 0 to tensors. Size (). If the parameters of the model have a total of 20 tensors, then the tensor index is divided into six buckets from 0 to 19. Among the six buckets, each tensor index is unique and does not repeat.

+-----------------------------------------------------------------------+
|                                                                       |
|  <tensor index 0, tensor index 1, tensor index 2, tensor index 3>     |
|                                                                       |
|                                                                       |
|  <tensor index 4, tensor index 5, tensor 6>                           |
|                                                                       |
|                                                                       |
|  ......                                                               |
|                                                                       |
|                                                                       |
|  <tensor index 16, tensor index 17, tensor index 18, tensor index 19> |
|                                                                       |
+-----------------------------------------------------------------------+

Next, let’s see how to initialize reducer.

0x02 Reducer 初始化

代码位于:torch/lib/c10d/reducer.h 和 torch/lib/c10d/reducer.cpp

2.1 constructor

The specific logic is as follows:

  • See if this module is a multi device module. Specifically, traverse the tensor to obtain the tensor device, and insert the device into a set structure. If there is more than one device in the set, it is a multi device module
  • 如果 expect_sparse_gradients没有设置,就把expect_sparse_gradients_初始化为false。
  • Call initialize_ Buckets initializes buckets and allocates parameters to buckets in reverse order as much as possible, so that communication by bucket can improve efficiency. Subsequently, the bucket may be reinitialized again at run time.
  • 为每个 parameter 加上 grad_accumulator,它们在 backward 时负责梯度同步。

    因为这些variables是autograd图的叶子张量,所以它们的grad_fn都被设置为 gradient accumulation function。
    Reducer保存了指向这些functions的指针,这样Reducer就可以知道它们在autograd传播之中是否被使用,如果没有使用,那么就把这些functions的梯度张量(grad tensors)设置为规约就绪状态。
    遍历张量,为每个张量生成一个类型为VariableIndex的变量index。
    得到Variable::AutogradMeta的grad_accumulator_,即用于累加叶子 Variable 的梯度累加器。
    把reducer的autograd_hook函数添加进去每个grad_accumulator_之中,变量index是hook的参数。这个 hook 挂在 autograd graph 之上,在 backward 时负责梯度同步。grad_accumulator 执行完后,autograd_hook 就会运行。

  • 因为这些variables是autograd图的叶子张量,所以它们的grad_fn都被设置为 gradient accumulation function。
  • The reducer saves pointers to these functions so that the reducer can know whether they are used in autograd propagation. If not, set the gradient tensors of these functions to the protocol ready state.
  • Traverse tensors and generate a variable index of type variableindex for each tensor.
  • 得到Variable::AutogradMeta的grad_accumulator_,即用于累加叶子 Variable 的梯度累加器。
  • 把reducer的autograd_hook函数添加进去每个grad_accumulator_之中,变量index是hook的参数。这个 hook 挂在 autograd graph 之上,在 backward 时负责梯度同步。grad_accumulator 执行完后,autograd_hook 就会运行。
  • gradAccToVariableMap_ 存了grad_accumulator & index 的对应关系(函数指针和参数张量的对应关系),这样以后在 autograd graph 遍历寻找 unused parameters 就方便了。
  • 初始化 backward_stats_。
  • 调用 initialize_local_used_map 初始化各种 unused map。
// The constructor takes a list of variables for every model replica.
// The bucket assignment for this reducer is specified as a list of
// buckets, each of which is specified as a list of indices into the
// variables list for **a single replica** (i.e. `variables[0]`).
Reducer::Reducer(
    std::vector<std::vector<at::Tensor>> replicas, // 张量
    std::vector<std::vector<size_t>> bucket_indices, // 桶信息
    c10::intrusive_ptr<c10d::ProcessGroup> process_group,
    std::vector<std::vector<bool>> expect_sparse_gradients,
    int64_t bucket_bytes_cap,
    bool find_unused_parameters,
    bool gradient_as_bucket_view,
    std::unordered_map<size_t, std::string> paramNames)
    : replicas_(std::move(replicas)),
      process_group_(std::move(process_group)),
      expect_sparse_gradients_(std::move(expect_sparse_gradients)),
      expect_autograd_hooks_(false),
      require_finalize_(false),
      next_bucket_(0),
      has_marked_unused_parameters_(false),
      find_unused_parameters_(find_unused_parameters),
      gradient_as_bucket_view_(gradient_as_bucket_view),
      local_used_maps_reduced_(false),
      num_iterations_(0),
      num_buckets_ready_(0),
      has_rebuilt_bucket_(false),
      bucket_bytes_cap_(bucket_bytes_cap),
      divFactor_(kUnsetDivFactor),
      static_graph_(false),
      comm_hook_(nullptr),
      thread_local_state_(at::ThreadLocalState()),
      ddp_debug_level_(parseDistDebugLevel()),
      param_names_(std::move(paramNames)) {

  // Check whether the module is multi_device_module
  // 看看本模块是不是多设备模块
  {
    std::set<int> unique_devices;
    for (const auto& v : replicas_[0]) { // 遍历张量
      auto device_idx = int(v.device().index()); // 得到张量的设备
      if (unique_devices.find(device_idx) == unique_devices.end()) {
        unique_devices.insert(device_idx); // 把设备插入到一个set结构之中
        if (unique_devices.size() > 1) { // 如果set内的设备多于一个,是多设备
          is_multi_device_module_ = true; 
          break;
        }
      }
    }
  }

  // If `expect_sparse_gradients` is not specified, initialize it such that
  // we do not expect sparse gradients for any parameter.
  if (expect_sparse_gradients_.empty()) {
    expect_sparse_gradients_ = std::vector<std::vector<bool>>(
        replicas_.size(), std::vector<bool>(replicas_[0].size(), false));
  }

  // Initialize variable bucketing.
  // This can be reinitialized later after capturing runtime information.
  {
    std::lock_guard<std::mutex> lock(mutex_);
    initialize_buckets(std::move(bucket_indices)); //初始化桶
  }

  // All variables are expected to have their `grad_fn` set to the gradient
  // accumulation function (since they are leafs in the autograd graph).
  // We store pointers to these functions such that we can check if they are
  // used in an autograd pass. If they are not, we know their grad tensors
  // can be marked as ready for reduction.
  {
    const auto replica_count = replicas_.size();
    grad_accumulators_.resize(replica_count);
    for (size_t replica_index = 0; replica_index < replica_count; // 只有replicas_[0]有意义
         replica_index++) {
      const auto variable_count = replicas_[replica_index].size(); //张量数目
      grad_accumulators_[replica_index].resize(variable_count); // 给grad_accumulators_分配内存
        
      for (size_t variable_index = 0; variable_index < variable_count;
           variable_index++) { // 遍历张量,variable_index 就是张量的index
        auto& variable = replicas_[replica_index][variable_index]; //得到具体的张量
        const auto index = VariableIndex(replica_index, variable_index); //每个张量生成一个VariableIndex

        // The gradient accumulator function is lazily initialized once.
        // Therefore we can use its presence in the autograd graph as
        // evidence that the parameter has participated in an iteration.
        auto grad_accumulator =
            torch::autograd::impl::grad_accumulator(variable); // 得到Variable::AutogradMeta的grad_accumulator_,即,用于累加叶子 Variable 的梯度累加器

#ifndef _WIN32
        using torch::distributed::autograd::ThreadLocalDistAutogradContext;
#endif
        // Hook to execute after the gradient accumulator has executed.
        hooks_.emplace_back(
            // 累加器添加hook,这个 hook 挂在 autograd graph 之上,在 backward 时负责梯度同步。
            // grad_accumulator 执行完后,autograd_hook 就会运行
            grad_accumulator->add_post_hook(
                torch::make_unique<torch::autograd::utils::LambdaPostHook>(
                    [=](const torch::autograd::variable_list& outputs,
                        const torch::autograd::variable_list& /* unused */) {
#ifndef _WIN32
                      this->rpc_context_.set(
                          ThreadLocalDistAutogradContext::getContextPtr());
#endif
                      this->autograd_hook(index); // 把reducer的autograd_hook函数添加进去
                      return outputs;
                    })),
            grad_accumulator);

        // Map raw function pointer to replica index and parameter index.
        // This is used later on when the autograd graph is traversed
        // to check for parameters for which no gradient is computed, if
        // find_unused_parameters=True.
        // Note that the mapping of gradient accumulator to variable should be
        // one to one as we deduplicate shared parameters before constructing
        // Reducer.
          
        // gradAccToVariableMap_ 存了grad_accumulator & index 的对应关系(函数指针和参数张量的对应关系),这样以后在 autograd graph 遍历寻找 unused parameters 就方便了
        if (find_unused_parameters_) {
          gradAccToVariableMap_[grad_accumulator.get()] = index;
        }

        numGradHooksTriggeredMap_[index] = 0;

        // The gradient accumulator is stored as weak_ptr in the autograd
        // metadata of the variable, so we have to keep it alive here for
        // the raw pointer to be valid.
        TORCH_CHECK(
            grad_accumulators_[replica_index][variable_index] == nullptr,
            c10::str(
                "Reducer tried to register duplicate grad accumulator for replica ",
                replica_index,
                " variable ",
                variable_index));
        grad_accumulators_[replica_index][variable_index] =
            std::move(grad_accumulator);
      }
    }
  }

  // Initialize backward stats vector.
  {
    const auto replica_count = replicas_.size();
    backward_stats_.resize(replica_count);
    const auto variable_count = replicas_[0].size();
    std::for_each(
        backward_stats_.begin(),
        backward_stats_.end(),
        [=](std::vector<int64_t>& v) { v.resize(variable_count); });
  }

  // See Note [Skip allreducing local_used_maps_dev]
  if (find_unused_parameters_) {
    initialize_local_used_map();
  }
}

Next, we analyze each part in detail.

2.2 initialization bucket

initialize_ The buckets method is used to initialize buckets. The specific logic is to add a model copy for each bucket and a tensor list for each model copy:

  • 用分布式上下文设置 rpc_context_。

    如果在DDP构造函数内调用initialize_bucket,则 rpc上下文指针(rpc context ptr)是否为null 无关紧要,因为grad不会发生变化。
    如果在训练循环期间调用initialize_bucket,例如在rebuild_bucket 内部,因为grad可能会发生改变并指向bucket_view,那么它需要检查rpc context ptr是否为null。
    如果rpc context ptr是null,则改变 variable.grad(),否则,在rpc上下文中改变梯度。

  • 如果在DDP构造函数内调用initialize_bucket,则 rpc上下文指针(rpc context ptr)是否为null 无关紧要,因为grad不会发生变化。
  • 如果在训练循环期间调用initialize_bucket,例如在rebuild_bucket 内部,因为grad可能会发生改变并指向bucket_view,那么它需要检查rpc context ptr是否为null。
  • 如果rpc context ptr是null,则改变 variable.grad(),否则,在rpc上下文中改变梯度。
  • 清空buckets_ 和 variable_locators_。
  • 重置variable_locators_的尺寸,这样每个variable都有一个bucket index。
  • 利用如下得到所有桶的个数和每个桶中副本个数:bucket_count = bucket_indices.size(); replica_count = replicas_.size();
  • 从0开始递增到 bucket_count,逐一初始化 Bucket。

    生成一个 Bucket bucket
    如果bucket_indices[bucket_index].size() == 1,说明这个桶期待一个single sparse gradient,则设置 bucket.expect_sparse_gradient = true。
    从0开始递增到replica_count,逐一初始化 BucketReplica。

    生成一个 BucketReplica replica
    如果这个桶期待一个single sparse gradient,则

    利用bucket_indices[bucket_index].front()取出向量第一个元素,设置为 variable_index。
    利用 variable_index 得到副本之中对应的variable。
    设置副本replica的变量列表,代码为replica.variables = {variable},这个副本只包括一个variable。

    否则说明是dense gradient,则

    遍历桶的variable,即利用 replicas_[replica_index][variable_index] 得到variable。
    设置variable的设备和数据类型
    给副本设置其variables,代码为:replica.variables.push_back(variable)。
    设置replica 的一些关于variable的元信息,这些元信息是flat contents相关的,比如offsets存储了各个张量在flat bucket contents中的offset。
    给relica.contents分配内存
    利用 initialize_bucket_views(replica, replica.contents) 初始化 cotnents 和 views。
    利用 bucket.replicas.push_back(std::move(replica)) 把这个 replica 加入到 bucket。

    遍历桶中的variable,代码为 bucket_indices[bucket_index]。

    设置 Reducer.variable_locators_,这样 Reducer 就知道如何在 bucket 之中确定一个varaible。bucket_index 是buckets_列表的位置,表示 buckets_ 之上的一个bucket。intra_bucket_index 是在 bucket replica 之中 vector 域的 variable index。

    设置桶的变量,bucket.variable_indices = std::move(bucket_indices[bucket_index]);
    利用 buckets_.push_back(std::move(bucket)) 把bucket这个桶加入到 Reducer之中。

  • 生成一个 Bucket bucket
  • 如果bucket_indices[bucket_index].size() == 1,说明这个桶期待一个single sparse gradient,则设置 bucket.expect_sparse_gradient = true。
  • 从0开始递增到replica_count,逐一初始化 BucketReplica。

    生成一个 BucketReplica replica
    如果这个桶期待一个single sparse gradient,则

    利用bucket_indices[bucket_index].front()取出向量第一个元素,设置为 variable_index。
    利用 variable_index 得到副本之中对应的variable。
    设置副本replica的变量列表,代码为replica.variables = {variable},这个副本只包括一个variable。

    否则说明是dense gradient,则

    遍历桶的variable,即利用 replicas_[replica_index][variable_index] 得到variable。
    设置variable的设备和数据类型
    给副本设置其variables,代码为:replica.variables.push_back(variable)。
    设置replica 的一些关于variable的元信息,这些元信息是flat contents相关的,比如offsets存储了各个张量在flat bucket contents中的offset。
    给relica.contents分配内存
    利用 initialize_bucket_views(replica, replica.contents) 初始化 cotnents 和 views。
    利用 bucket.replicas.push_back(std::move(replica)) 把这个 replica 加入到 bucket。

  • 生成一个 BucketReplica replica
  • 如果这个桶期待一个single sparse gradient,则

    利用bucket_indices[bucket_index].front()取出向量第一个元素,设置为 variable_index。
    利用 variable_index 得到副本之中对应的variable。
    设置副本replica的变量列表,代码为replica.variables = {variable},这个副本只包括一个variable。

  • 利用bucket_indices[bucket_index].front()取出向量第一个元素,设置为 variable_index。
  • 利用 variable_index 得到副本之中对应的variable。
  • 设置副本replica的变量列表,代码为replica.variables = {variable},这个副本只包括一个variable。
  • 否则说明是dense gradient,则

    遍历桶的variable,即利用 replicas_[replica_index][variable_index] 得到variable。
    设置variable的设备和数据类型
    给副本设置其variables,代码为:replica.variables.push_back(variable)。
    设置replica 的一些关于variable的元信息,这些元信息是flat contents相关的,比如offsets存储了各个张量在flat bucket contents中的offset。
    给relica.contents分配内存
    利用 initialize_bucket_views(replica, replica.contents) 初始化 cotnents 和 views。
    利用 bucket.replicas.push_back(std::move(replica)) 把这个 replica 加入到 bucket。

  • 遍历桶的variable,即利用 replicas_[replica_index][variable_index] 得到variable。
  • Set the device and data type of variable
  • 给副本设置其variables,代码为:replica.variables.push_back(variable)。
  • 设置replica 的一些关于variable的元信息,这些元信息是flat contents相关的,比如offsets存储了各个张量在flat bucket contents中的offset。
  • 给relica.contents分配内存
  • 利用 initialize_bucket_views(replica, replica.contents) 初始化 cotnents 和 views。
  • 利用 bucket.replicas.push_back(std::move(replica)) 把这个 replica 加入到 bucket。
  • 遍历桶中的variable,代码为 bucket_indices[bucket_index]。

    设置 Reducer.variable_locators_,这样 Reducer 就知道如何在 bucket 之中确定一个varaible。bucket_index 是buckets_列表的位置,表示 buckets_ 之上的一个bucket。intra_bucket_index 是在 bucket replica 之中 vector 域的 variable index。

  • 设置 Reducer.variable_locators_,这样 Reducer 就知道如何在 bucket 之中确定一个varaible。bucket_index 是buckets_列表的位置,表示 buckets_ 之上的一个bucket。intra_bucket_index 是在 bucket replica 之中 vector 域的 variable index。
  • 设置桶的变量,bucket.variable_indices = std::move(bucket_indices[bucket_index]);
  • 利用 buckets_.push_back(std::move(bucket)) 把bucket这个桶加入到 Reducer之中。

The specific code is:

void Reducer::initialize_buckets(
    std::vector<std::vector<size_t>> bucket_indices) {
  // If initialize_buckets is called inside DDP constructor, then
  // it does not matter rpc context ptr is nullptr or not, as grad
  // will not be mutated.
  // If initialize_buckets is called during training loop, e.g, inside
  // rebuild_buckets(), since grad could be mutated and be pointed to
  // bucket_view, then it needs to check rpc context ptr is nullptr or not,
  // If rpc context ptr is nullptr, mutate variable.grad(); otherwise,
  // mutate grad in rpc context.
#ifndef _WIN32
  using torch::distributed::autograd::ThreadLocalDistAutogradContext;
  this->rpc_context_.set(ThreadLocalDistAutogradContext::getContextPtr());
#endif

  // This shouldn't be called if we're expecting autograd hooks to fire.
  TORCH_CHECK(
      !expect_autograd_hooks_,
      "`initialize_buckets` must NOT be called during autograd execution.");

  // Clear current bucket assignment.
  buckets_.clear();
  variable_locators_.clear();

  // Ensure we have a bucket index for every variable.
  variable_locators_.resize(replicas_[0].size());

  // Iterate over buckets.
  const auto bucket_count = bucket_indices.size();
  const auto replica_count = replicas_.size();
  buckets_.reserve(bucket_count);
  // 从0开始递增到bucket_count
  for (size_t bucket_index = 0; bucket_index < bucket_count; bucket_index++) {
    Bucket bucket; // 生成一个桶

    // TODO(@pietern): Validate indices.
    // Must be non-empty, unique, and unique across buckets.
    TORCH_CHECK(
        bucket_indices[bucket_index].size() > 0, "Empty bucket specified.");

    // Variables that expect sparse gradients must have their own bucket.
    if (bucket_indices[bucket_index].size() == 1) {
      // 说明这个桶期待一个single sparse gradient
      const auto variable_index = bucket_indices[bucket_index].front();
      bucket.expect_sparse_gradient =
          expect_sparse_gradients_[0][variable_index];
    } else {
      for (const auto variable_index : bucket_indices[bucket_index]) {
        TORCH_CHECK(
            !expect_sparse_gradients_[0][variable_index],
            "Buckets with more than one variable cannot include variables ",
            "that expect a sparse gradient.");
      }
    }

    // Iterate over model replicas. 从0开始递增到replica_count,遍历模型副本数目,为每一个模型副本都要做同样设置
    for (size_t replica_index = 0; replica_index < replica_count;
         replica_index++) {
      BucketReplica replica; // 生成一个副本

      if (bucket.expect_sparse_gradient) {
        // 说明这个桶期待一个single sparse gradient
        const auto variable_index = bucket_indices[bucket_index].front(); // 得到张量的index
        const auto& variable = replicas_[replica_index][variable_index]; // 得到张量
        TORCH_INTERNAL_ASSERT(bucket_indices[bucket_index].size() == 1);
        replica.variables = {variable}; // 这个副本只包括一个variable
      } else {
        at::TensorOptions options;
        // The start index of the variable in the flattened tensor.
        size_t offset = 0;

        // Reserve enough space for the per-variable fields stored in bucket
        // replica for efficiency.
        const size_t num_variables = bucket_indices[bucket_index].size();
        replica.variables.reserve(num_variables); 
        replica.offsets.reserve(num_variables);
        replica.lengths.reserve(num_variables);
        replica.sizes_vec.reserve(num_variables);

        // Iterate over bucket variables.
        for (const auto variable_index : bucket_indices[bucket_index]) { //遍历桶中的variable
          TORCH_CHECK(
              variable_index < replicas_[replica_index].size(),
              "Out of range variable index specified.");
          const auto& variable = replicas_[replica_index][variable_index];
          if (!options.has_device()) {
            options = options.device(variable.device());
          } else {
            TORCH_CHECK(
                variable.device() == options.device(),
                "All parameters in a bucket must be ",
                "placed on the same device.");
          }
          if (!options.has_dtype()) {
            options = options.dtype(variable.dtype());
          } else {
            TORCH_CHECK(
                variable.dtype() == options.dtype(),
                "All parameters in a bucket must have the same dtype.");
          }
          
          const auto length = variable.numel();
          // 给副本设置其variables
          replica.variables.push_back(variable); // 这里添加了一个新变量,所以最终能知道该桶中的变量数目
          // 设置replica 的一些关于variable的元信息
          replica.offsets.push_back(offset);
          replica.lengths.push_back(length);
          replica.sizes_vec.push_back(variable.sizes());
          offset += length;
        }

        // Allocate bucket contents tensor.
        replica.contents = at::empty({static_cast<long>(offset)}, options);

        initialize_bucket_views(replica, replica.contents); // 初始化cotents和views
      }

      // Add bucket replica to enclosing bucket.
      bucket.replicas.push_back(std::move(replica)); // 桶的副本列表中添加一个新副本
    }

    // Map participating variables to this bucket.
    // This is identical across replicas so we only need to do this once.
    size_t intra_bucket_index = 0;
    for (const auto variable_index : bucket_indices[bucket_index]) { // 遍历桶中的variable
      TORCH_CHECK(
          variable_index < variable_locators_.size(),
          "Out of range variable index specified.");
      variable_locators_[variable_index] = // 这样 Reducer 就知道如何在 bucket 之中确定一个varaible
          VariableLocator(bucket_index, intra_bucket_index++);
    }
    bucket.variable_indices = std::move(bucket_indices[bucket_index]);

    buckets_.push_back(std::move(bucket)); // 把桶插入Reducer
  }
}

2.3 initialization view

initialize_bucket_views 这里是设置 Replica 的contents 和 views。

// (see Note:  "Gradient Layout Contract" in initialize_buckets).
void Reducer::initialize_bucket_views(
    Reducer::BucketReplica& replica,
    at::Tensor& contents) {
  for (size_t i = 0; i < replica.variables.size(); i++) {
    auto& v = replica.variables[i];
    const auto offset = replica.offsets[i];
    const auto length = replica.lengths[i];
    if (v.is_non_overlapping_and_dense()) { // Dense类型的张量
      // If the param's memory is dense, match its layout, anticipating
      // the autograd engine (AccumulateGrad) will also create gradients
      // matching its layout.
      replica.bucket_views_in.push_back( // replica.bucket_views_in里面都是视图
          contents.as_strided(v.sizes(), v.strides(), offset));
    } else { // Sparse类型的张量
      // Fall back to a C-style contiguous view, again anticipating
      // AccumulateGrad will do the same when stashing grads for non-dense
      // params.
      replica.bucket_views_in.push_back( // replica.bucket_views_in里面都是视图
          contents.narrow(0, offset, length).view(v.sizes()));
    }
    // By default `bucket_views_out` and `bucket_views_in` are
    // essentially the same thing.
    replica.bucket_views_out = replica.bucket_views_in; // out也是视图

    // If gradient_as_bucket_view_ is set as true, then there are two cases to
    // handle: initialize_bucket_views could be called inside initialize_buckets
    // when rebuild_buckets, if grad has already been defined/calculated in
    // previous iteration, old grad needs to be copied into new bucket_view and
    // let grad point to the new bucket_view, initialize_bucket_views could also
    // be called inside initialize_buckets during construction. Grads are not
    // defined during construction time, in this case, do not let grad point to
    // bucket_view, because grads should be kept as being undefined for globally
    // unused parameters.
    if (gradient_as_bucket_view_) {
      auto& bucket_view = replica.bucket_views_in.back();
      runGradCallbackForVariable(v, [&](auto& grad) {
        if (grad.defined() && !grad.is_alias_of(bucket_view)) {
          bucket_view.copy_(grad);
          grad = bucket_view; // 梯度被修改了,需要回写
          // The grad is modefied and needs to be written back.
          return true;
        }
        // The grad is not modified and does not need to be written back.
        return false; // 不需要回写,因为没有被修改
      });
    }
  }
}

2.3.1 BucketReplica成员变量

Let’s first recall several member variables of bucketreplica.

  • at::Tensor contents :把桶的内容展平的结果,即Flattened (1 dimensional) 之后的结果。
  • std::vector bucket_views_in :提供了从输入角度在 contents 之中查看具体梯度的方法。
  • std::vector bucket_views_out :提供了从输入角度在 contents 之中查看具体梯度的方法。

Further notes on and:

std::vector<at::Tensor> bucket_views_in
std::vector<at::Tensor> bucket_views_out
  • These two variables provide methods to manipulate specific gradients in contents, or they provide views that can manipulate the gradients of each tensor in contents. Users use these two variables as entry points to move the data of each gradient in and out of the content.
  • In pytorch, view refers to creating something convenient to view. The view shares memory with the original data. It just arranges the original data, directly displays some of its contents, or displays it after reordering.

Several pytorch functions also need to be described.

  • as_ Striped: create a view according to the existing tensor and the given step size (the type is still tensor). Note that the result here is a view, so this tensor still shares memory with the original tensor.
  • Narrow: returns a new tensor, which is a reduced version of the original tensor, but this tensor still shares memory with the original tensor.

Bucket replica logic is shown in the following figure:

+------------------------------------------+
| BucketReplica                            |
|                                          |
|       vector<Tensor> bucket_views_in +--------------------+
|                                          |                |
|                                          |                |
|       vector<Tensor> bucket_views_out +--------------+    |
|                                          |           |    |
|                                          |           |    |
|                                          |           v    v
|                                          |     +-----+----+--------------------------+
|       Tensor contents  +---------------------> |Flattened (Tensor1, Tensor2, Tensor3)|
|                                          |     +-------------------------------------+
|                                          |
|                                          |
|       vector<Tensor> variables  +------------>  [Tensor1,Tensor2,Tensor3]
|                                          |
|                                          |
|                                          |
+------------------------------------------+

2.3.2 calling

How to call? If set to true, two situations need to be handled:

gradient_as_bucket_view_
  • rebuild_buckets 之中可以在initialize_bucket内调用initialize_bucket_view,如果grad在上一次迭代中已经定义/计算过,则需要将旧的grad复制到新的bucket_view中,并让grad指向新的bucket_view,
  • During construction, you can also initialize_ Calling initialize_ in bucket bucket_ views。 Gradients are not defined during construction. In this case, do not let the gradients point to buckets_ View, because for parameters not used globally, the gradient should remain undefined.

2.4 initializing local variables

initialize_ local_ used_ Map here is initialization. We recall the content of this paper, which is used to find global unused parameters:

local_used_maps_
local_used_maps_

The gradient of global unused parameters should remain unchanged in the forward and backward process. Detecting unused parameters requires global information, because in a DDP process, a parameter may not exist in one operation, but may participate in training in the same iteration of another process. Therefore, DDP maintains locally unused parameter information in the bitmap and starts additional allreduce to collect the global bitmap. Since the bitmap is much smaller than the tensor size, all parameters in the model share the same bitmap instead of creating per bucket bitmaps. The bitmap is located on the CPU to avoid starting a dedicated CUDA kernel for each update. However, some processgroup backend may not be able to run allreduce on the CPU tensor. For example, processgroupnccl only supports CUDA tensors. In addition, since DDP should work with any custom processgroup backend, it cannot assume that all backend supports CPU tensor. To solve this problem, DDP maintains another bitmap on the same device as the first model parameter, and calls a non blocking copy to move the CPU bitmap to the device bitmap for collective communication.

The gradient of global unused parameters should remain unchanged in the forward and backward process. Detecting unused parameters requires global information, because in a DDP process, a parameter may not exist in one operation, but may participate in training in the same iteration of another process. Therefore, DDP maintains locally unused parameter information in the bitmap and starts additional allreduce to collect the global bitmap. Since the bitmap is much smaller than the tensor size, all parameters in the model share the same bitmap instead of creating per bucket bitmaps. The bitmap is located on the CPU to avoid starting a dedicated CUDA kernel for each update. However, some processgroup backend may not be able to run allreduce on the CPU tensor. For example, processgroupnccl only supports CUDA tensors. In addition, since DDP should work with any custom processgroup backend, it cannot assume that all backend supports CPU tensor. To solve this problem, DDP maintains another bitmap on the same device as the first model parameter, and calls a non blocking copy to move the CPU bitmap to the device bitmap for collective communication.

The specific codes are as follows:

void Reducer::initialize_local_used_map() {
  const auto replica_count = replicas_.size();
  const auto variable_count = replicas_[0].size();
  local_used_maps_.resize(replica_count);
  local_used_maps_dev_.resize(replica_count);

  for (size_t i = 0; i < replica_count; i++) {
    at::TensorOptions options;
    options = options.dtype(at::kInt);

    // Deliberately don't pin the memory even if local_used_maps_dev_ will
    // be cuda. See Note [local_used_maps_ -> local_used_maps_dev copying]
    local_used_maps_[i] =
        at::zeros({static_cast<long>(variable_count)}, options);

    // This tensor needs to be on the same device as replica because backend
    // such as NCCL may not support CPU tensors, and hence it might not work
    // if we always put it on CPU.
    options = options.device(replicas_[i][0].device());
    local_used_maps_dev_[i] =
        at::empty({static_cast<long>(variable_count)}, options);
  }
}

The initialization process is as follows:

                                    +
                                    |
                                    |
                                    v
                  rpc_context_ = ThreadLocalDistAutogradContext
                                    +
                                    |
                                    |
                                    v
                  buckets_ & variable_locators_ (clear & resize)
                                    +
                                    |
                                    |
                                    v
+----------------------->  from 0 ~ bucket_count :  +--------------------------->
|                                                                                +
|                                                                                |
|      +-------------------------------------------------------------------+     |
|      | init Bucket          set bucket_indices                           |     |
|      |                            +                                      |     |
|      |                            |                                      |     |
|      |                            |                                      |     |
|      |                            v                                      |     |
|      |   ^ +------------> from 0 ~ replica_count : +----------------->   |     |
|      |   |                                                           |   |     |
|      |   |  +---------------------------------------------------+    |   |     |
|      |   |  | init BucketReplica                                |    |   |     |
|      |   |  |                                                   |    |   |     |
<----+ |   +--+                                                   | <--+   | <---+
       |      |    bucket.replicas.push_back(std::move(replica))  |        |
       |      |                                                   |        |
       |      +----------------------+----------------------------+        |
       |                             |                                     |
       |                             |                                     |
       |                             v                                     |
       |             buckets_.push_back(std::move(bucket))                 |
       |                             +                                     |
       +-------------------------------------------------------------------+
                                     |
                                     v

The reducers obtained are roughly as follows. It should be noted that there is only one bucket in bucketreplica:

            +----------------------------------------+                 +------------------+
            |tensor index 4, tensor index 5, tensor 6| <------+        | index 2, index 3 |
            +----------------------------------------+        |        +--------------+---+
                                                              |                       ^
                                                              |                       |
+---------------------------+   +---------------------------------------------------------+
| Reducer                   |   | +----------------------------------+     +------------+ |
|                           |   | |Bucket                     |      |     |Bucket    | | |
|                           |   | |                           +      |     |          | | |
| vector<Bucket> buckets_ +---> | | vector<size_t> variable_indices  |     | indices ++ | |
|                           |   | |                                  |     |            | |
|                           |   | |  vector<BucketReplica> replicas  | ... | replicas   | |
|                           |   | |                         +        |     |   +        | |
|                           |   | |                         |        |     |   |        | |
|                           |   | +----------------------------------+     +------------+ |
|                           |   |                           |                  |          |
+---------------------------+   +---------------------------------------------------------+
                                                            |                  |
                                                            |                  |
                                                            v                  v
                          +---------------------------------------+   +-------------------+
                          |  +----------------------------------+ |   | +---------------+ |
                          |  | BucketReplica                    | |   | | BucketReplica | |
                          |  |                                  | |   | |               | |
                          |  |                                  | |   | |               | |
                          |  |  vector<Tensor> bucket_views_in  | |   | |   views_in    | |
                          |  |                                  | |   | |               | |
                          |  |  vector<Tensor> bucket_views_out | |   | |   views_out   | |
                          |  |                                  | |   | |               | |
                          |  |  Tensor contents                 | |   | |   contents    | |
                          |  |                                  | |   | |               | |
                          |  |  vector<Tensor> variables        | |   | |   variables   | |
                          |  |                     +            | |   | |      +        | |
                          |  +----------------------------------+ |   | +---------------+ |
                          +---------------------------------------+   +-------------------+
                                                   |                           |
                                                   |                           |
                                                   v                           v
                                   +---------------+------------+    +---------+----------+
                                   |Tensor 4, Tensor 5, Tensor 6|    | Tensor 2, Tensor 3 |
                                   +----------------------------+    +--------------------+

0x03 static diagram

3.1 reasons

Although pytorch is a dynamic graph, the user can clearly let DDP know that the training graph is static. It can be set in the following cases:

  • Used and unused parameter sets remain unchanged throughout the training cycle. In this case, will the user find_ unsued_ Setting parameters to true is not important.
  • The training mode of graphics will not change during the whole training cycle (which means that there is no control flow dependent on iteration). When the graph is set to static, DDP will support cases that were not previously supported, such as:
    Reentrant back propagation.
    Multiple activation checkpointing.
    Activation checkpointing and find_ unused_ parameters = true。
    Not all output tensors are used for loss calculation..
    There is a model parameter outside the forward function.
    When find_ unsued_ When parameters = true or there are unused parameters, performance may be improved because DDP does not search the network to check unused parameters within each iteration.
  • Reentrant back propagation.
  • 多次activation checkpointing。
  • activation checkpointing 并且find_unused_parameters = true。
  • Not all output tensors are used for loss calculation..
  • There is a model parameter outside the forward function.
  • When find_ unsued_ When parameters = true or there are unused parameters, performance may be improved because DDP does not search the network to check unused parameters within each iteration.

3.2 use

_ set_ static_ Graph can configure static diagrams. This API should be constructed after DistributedDataParallel and is called before the training cycle starts. Also, all rank should be called in the same way. For example:

ddp_model = DistributedDataParallel(model)
ddp_model._set_static_graph()
for i in range(n):

_ set_ static_ The graph code is:

def _set_static_graph(self):
    """
    Users can explicitly let DDP know the trained graph is static,
    when 1) the set of used and unused parameters will not change
    during the whole training loop; in this case, it does not matter
    whether users set find_unsued_parameters = true or not.
    2) how the graph is trained will not change during the whole training
    loop (meaning there is no control flow depending on iterations).
    When graph is set to be static, DDP will support cases that can not
    be supported in the past: 1) reentrant backwards
    2) activation checkpointing multiple times 3)
    activation checkpointing with find_unused_parameters = true.
    4) not all output tensors are used in loss calculation.
    5) there is model parameter that is outside of forward function.
    6) potentially improve performance when find_unsued_parameters = true
    or there are unused parameters, as DDP will not search graph in each
    iteraton to detect unused parameters when static_graph is set to be True.

    This API should be called after DistributedDataParallel construction, and
    before training loops starts. Also it should be called in the same way for
    all ranks. For example:
        ddp_model = DistributedDataParallel(model)
        ddp_model._set_static_graph()
        for i in range(n):
            .....
    """
    self.static_graph = True
    self.reducer._set_static_graph() # configure the Reducer as well
    self.logger._set_static_graph()
    if self.find_unused_parameters:
        warnings.warn(
            "You passed find_unused_parameters=true to DistributedDataParallel, "
            "`_set_static_graph` will detect unused parameters automatically, so "
            "you do not need to set find_unused_parameters=true, just be sure these "
            "unused parameters will not change during training loop while calling "
            "`_set_static_graph`."
        )

3.2 Reducer

The Reducer can only generate the static graph after the first iteration, because PyTorch is, after all, still dynamic: at least one iteration has to run before the graph is known.

void Reducer::set_static_graph() {
  std::lock_guard<std::mutex> lock(mutex_);
  TORCH_CHECK(
      num_iterations_ == 0,
      "set_static_graph() should be called before training loop starts "
      "and after DistributedDataParallel is constructed.");
  static_graph_ = true;
  // when static_graph_ is set as true, always initialize_local_used_map
  // and detect the global unused parameters in the first iteration.
  initialize_local_used_map();
}

0x04 Rebuilding buckets

4.1 Why rebuild

Because PyTorch generates the computation graph dynamically, the buckets need to be rebuilt to match it. However, buckets are rebuilt only once, after the first iteration, and they are never rebuilt when find_unused_parameters_ is set.

  // Returns true if we should rebuild buckets, else false. We only rebuild
  // buckets once after the first iteration and never rebuild them if
  // find_unused_parameters_.
  inline bool should_rebuild_buckets() const {
    return (static_graph_ || !find_unused_parameters_) && !has_rebuilt_bucket_;
  }
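
For readability, here is a direct Python rendering of the predicate above. It is only an illustration of the condition; the real logic lives in the C++ Reducer code shown above.

def should_rebuild_buckets(static_graph: bool,
                           find_unused_parameters: bool,
                           has_rebuilt_bucket: bool) -> bool:
    # Rebuild at most once, and only when the graph is static or
    # unused-parameter detection is disabled.
    return (static_graph or not find_unused_parameters) and not has_rebuilt_bucket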

4.2 preparation for reconstruction

Let us first look at some preparation work before rebuilding.

push_rebuilt_params inserts a parameter into the rebuild lists.

void Reducer::push_rebuilt_params(const VariableIndex& index) {
  rebuilt_params_.push_back(
      replicas_[index.replica_index][index.variable_index]);
  rebuilt_param_indices_.push_back(index.variable_index);
}

Second, push_rebuilt_params_for_all_indices traverses every replica and pushes every variable of each replica.

void Reducer::push_rebuilt_params_for_all_indices() {
  std::lock_guard<std::mutex> lock(mutex_);
  if (!should_rebuild_buckets() || !rebuilt_param_indices_.empty()) {
    return;
  }
  const auto replica_count = replicas_.size();
  for (size_t replica_index = 0; replica_index < replica_count;
       ++replica_index) {
    const auto variable_count = replicas_[replica_index].size();
    for (size_t variable_index = 0; variable_index < variable_count;
         ++variable_index) {
      const auto index = VariableIndex(replica_index, variable_index);
      push_rebuilt_params(index);
    }
  }
}

4.3 Rebuilding

Let us now look at the rebuilding mechanism.

DDP rebuilds the buckets using rebuilt_params_ and rebuilt_param_indices_, which reflect the times at which tensors received their gradients during backward propagation.

The rebuild_buckets function makes broadcast communication calls and can overlap with the next forward() call, so it can be asynchronous.

  • When find_unused_parameters=true, rebuilding buckets is an asynchronous operation, because buckets may be rebuilt multiple times: subgraphs are trained, so the parameter index order may change more frequently.
  • When find_unused_parameters=false, buckets are rebuilt only once and the performance cost is negligible. rebuild_buckets returns true if the buckets were rebuilt.
bool Reducer::rebuild_buckets() {
  // Ensure reduction for previous backwards pass is finished. If user's model
  // has unused parameters for example, this will raise an error recommending to
  // run with find_unused_parameters=True, instead of the size mismatch
  // exception below.
  std::lock_guard<std::mutex> lock(mutex_);
  ensure_prior_reduction_finished();
  if (!should_rebuild_buckets() || rebuilt_params_.empty()) {
    return false;
  }

  std::vector<std::vector<size_t>> rebuilt_bucket_indices;
  std::vector<size_t> bucket_size_limits;
  bucket_size_limits.push_back(kDefaultFirstBucketBytes);
  bucket_size_limits.push_back(bucket_bytes_cap_);
  rebuilt_bucket_indices = compute_bucket_assignment_by_size(
      rebuilt_params_,
      bucket_size_limits,
      expect_sparse_gradients_[0],
      rebuilt_param_indices_);

  // For rebuilt bucket indices, it needs to be synced across all ranks.
  // Broadcast the newly rebuilt bucket indices from rank 0 in default.
  // After syncing up rebuilt bucket indices, initialize buckets for reducer.
  sync_bucket_indices(rebuilt_bucket_indices);

  has_rebuilt_bucket_ = true; // rebuild only once
  rebuilt_params_.clear();
  rebuilt_param_indices_.clear();

  initialize_buckets(std::move(rebuilt_bucket_indices));
  return true;
}
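
To make the role of the two size limits (kDefaultFirstBucketBytes and bucket_bytes_cap_) more concrete, below is a simplified, hypothetical Python sketch of size-based bucket assignment. It only illustrates the greedy grouping idea behind compute_bucket_assignment_by_size; the real implementation additionally handles sparse gradients, grouping by dtype/device, and the synced parameter index order.

from typing import List
import torch

def assign_buckets_by_size(tensors: List[torch.Tensor],
                           size_limits: List[int]) -> List[List[int]]:
    buckets: List[List[int]] = []   # each bucket is a list of tensor indices
    current: List[int] = []
    current_bytes = 0
    limit_idx = 0
    for i, t in enumerate(tensors):
        nbytes = t.numel() * t.element_size()
        limit = size_limits[min(limit_idx, len(size_limits) - 1)]
        if current and current_bytes + nbytes > limit:
            buckets.append(current)     # close the current bucket
            current, current_bytes = [], 0
            limit_idx += 1              # after the first (small) bucket, use the larger cap
        current.append(i)
        current_bytes += nbytes
    if current:
        buckets.append(current)
    return buckets

With DDP's defaults this would be called with roughly size_limits = [1 MB, 25 MB]: a small first bucket so that the first allreduce can start early, and the bucket_cap_mb limit for the remaining buckets.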

4.4 When to set rebuilding

Rebuilding is set only when all of the following conditions hold:

  • It is the first time to rebuild the buckets.
  • static_graph_ is true or find_unused_parameters_ is false.
  • This backward pass needs to run allreduce.

Here, we just dump the tensors and their parameter indices into rebuilt_params_ and rebuilt_param_indices_ based on the order in which gradients arrive. Then, at the end of finalize_backward(), the buckets are rebuilt based on rebuilt_params_ and rebuilt_param_indices_, after which the bucket indices are broadcast and the buckets are initialized.

Also, we only need to dump the tensors and parameter indices of one replica.

Take mark_variable_ready as an example, where push_rebuilt_params(index) is called to push into the rebuild lists.

void Reducer::mark_variable_ready(VariableIndex index) {
  // Rebuild bucket only if 1) it is the first time to rebuild bucket 2)
  // static_graph_ is true or find_unused_parameters_ is false,
  // 3) this backward pass needs to run allreduce.
  // Here, we just dump tensors and their parameter indices into
  // rebuilt_params_ and rebuilt_param_indices_ based on gradient arriving
  // order, and then at the end of finalize_backward(), buckets will be
  // rebuilt based on rebuilt_params_ and rebuilt_param_indices_, and then
  // will be broadcasted and initialized. Also we only need to dump tensors
  // and parameter indices of one replica.
  if (should_rebuild_buckets()) {
    push_rebuilt_params(index); // push into the rebuild lists
  }

  const auto replica_index = index.replica_index;
  const auto variable_index = index.variable_index;

  if (replica_index == 0) {
    checkAndRaiseMarkedTwiceError(variable_index);
    perIterationReadyParams_.insert(variable_index);
  }
  backward_stats_[replica_index][variable_index] =
      current_time_in_nanos() - cpu_timer_.backward_compute_start_time;

  // Any time we mark a variable ready (be it in line due to unused parameters,
  // or via an autograd hook), we require a call to the finalize function. If
  // this doesn't happen before the next iteration (or call to
  // `prepare_for_backwards`), we know something is wrong.
  require_finalize_ = true;

  const auto& bucket_index = variable_locators_[variable_index];
  auto& bucket = buckets_[bucket_index.bucket_index];
  auto& replica = bucket.replicas[replica_index];

  set_divide_factor();

  if (bucket.expect_sparse_gradient) {
    mark_variable_ready_sparse(index);
  } else {
    mark_variable_ready_dense(index);
  }

  // TODO(@pietern): Make this work for both CPU/CUDA tensors.
  // When using CPU tensors we don't need to do this.
  // // Record event so that we can wait for all of them.
  // auto& event = replica.events[bucket_index.intra_bucket_index];
  // event.record();

  // Check if this was the final gradient for this bucket.
  if (--replica.pending == 0) {
    // Kick off reduction if all replicas for this bucket are ready.
    if (--bucket.pending == 0) {
      mark_bucket_ready(bucket_index.bucket_index);
    }
  }

  // Run finalizer function and kick off reduction for local_used_maps once the
  // final bucket was marked ready.
  if (next_bucket_ == buckets_.size()) {

    if (dynamic_graph_find_unused()) {
      all_reduce_local_used_map();
    }

    // The autograd engine uses the default stream when running callbacks, so we
    // pass in the current CUDA stream in case it is not the default.
    const c10::Stream currentStream = get_current_stream();
    torch::autograd::Engine::get_default_engine().queue_callback([=] {
      std::lock_guard<std::mutex> lock(this->mutex_);
      // Run callback with the current stream
      c10::OptionalStreamGuard currentStreamGuard{currentStream};
      if (should_collect_runtime_stats()) {
        record_backward_compute_end_time();
      }
      // Check that all buckets were completed and had their work kicked off.
      TORCH_INTERNAL_ASSERT(next_bucket_ == buckets_.size());
      this->finalize_backward();
    });
  }
}

4.5 Direct calls

The _rebuild_buckets function can also be called directly. For example, it is called in forward, but only once during the whole training period.

def forward(self, *inputs, **kwargs):
    with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
        self.reducer.save_thread_local_state()
        if torch.is_grad_enabled() and self.require_backward_grad_sync:
            self.num_iterations += 1
            self.reducer.prepare_for_forward()
        if self.ddp_uneven_inputs_config.ddp_join_enabled:
            ones = torch.ones(1, device=self.device)
            work = dist.all_reduce(ones, group=self.process_group, async_op=True)
            if self.ddp_uneven_inputs_config.ddp_join_throw_on_early_termination:
                # Active ranks schedule an allreduce with zeros, inactive
                # ranks schedule them with 1. If the result != 0 it
                # indicates at least one rank has terminated and we should
                # throw.
                zeros = torch.zeros(1, device=self.device)
                dist.all_reduce(zeros, group=self.process_group)
                should_throw_stop_iteration = zeros.item()
                if should_throw_stop_iteration:
                    raise RuntimeError(
                        "Detected at least one rank that exhausted inputs. Throwing across all ranks."
                    )
            else:
                self.reducer._set_forward_pass_work_handle(
                    work,
                    self.ddp_uneven_inputs_config.ddp_join_divide_by_initial_world_size,
                )

        # Calling _rebuild_buckets before forward compuation,
        # It may allocate new buckets before deallocating old buckets
        # inside _rebuild_buckets. To save peak memory usage,
        # call _rebuild_buckets before the peak memory usage increases
        # during forward computation.
        # This should be called only once during whole training period.
        
        # The direct call is made here
        if torch.is_grad_enabled() and self.reducer._rebuild_buckets(): # trigger the rebuild
            logging.info("Reducer buckets have been rebuilt in this iteration.")

As another example, the join method also calls _rebuild_buckets directly.

@contextmanager
def join(
    self,
    divide_by_initial_world_size=True,
    enable=True,
    throw_on_early_termination=False,
):
  
    # ... other code omitted ...
    
                    else:
                        # Some DDP process still needs to be joined.
                        if self.ddp_uneven_inputs_config.ddp_join_throw_on_early_termination:
                            # Schedule allreduce telling active ranks to terminate
                            ones = torch.ones(1, device=self.device)
                            dist.all_reduce(ones, group=self.process_group)
                            # Raising StopIteration doesn't throw error in python 3.6
                            # and throws RuntimeError in 3.7+ (PEP 479), so just
                            # raise RuntimeError here.
                            raise RuntimeError(
                                f"Rank {self._distributed_rank} exhausted all inputs."
                            )
                        if is_last_joiner:
                            is_last_joiner = False
                        # It will rebuild buckets only once during training period
                        
                        # The call is made here
                        self.reducer._rebuild_buckets()
                        # Schedule a corresponding broadcast if we are syncing module
                        # buffers in the forward pass.
                        self._check_and_sync_module_buffers()   

Now that we have mentioned join, let’s take a look at this concept.

0x05 Join

Join is designed to solve the problem of uneven training data: it allows ranks with fewer inputs (which have already performed the join operation) to keep participating in collective communications with the ranks that have not yet finished, by issuing dummy (shadow) operations.

5.1 Origin

Behind DDP are the allreduce operations of various collective communication libraries, which synchronize gradients across workers. When the training data is unevenly distributed across ranks, DDP hangs, because collective communication requires every rank in the process group to participate: if one rank runs out of inputs, the other ranks will hang or raise an error (depending on the backend). In fact, any class that performs synchronous collective communication in every iteration runs into this problem.

Therefore, DDP provides the Join API, a context manager used around the training loop of each rank. A rank with less data exhausts its inputs early; at that point it fakes participation in the collective communication, building dummy allreduces to match the other ranks. How this illusion is created is specified by the registered hooks.

The general idea is as follows:

                +----------------------------+
                |             Data           |
                |   +--------+   +--------+  |
                |   |        |   | Empty  |  |
                |   |        |   |        |  |
                |   +-----+--+   +--------+  |
                |         |                  |
                |         |                  |
                +----------------------------+
                          |
                          |
        +------------+    |               +------------+
        |            |    |               |            |
+---->  |    Model   |    |               |   Model    | <-----+
|       |            |    |               |            |       |
|       +------+-----+    |               +------+-----+       |
|              |          |                      |             |
|              |          |                      |             |
|              v          |                      v             |
|       +------+-----+    |             +--------+----------+  |
|       |  Forward   +<---+             | _JoinHook         |  |
|       |  (local)   |                  |                   |  |
|       +------+-----+                  |                   |  |
|              |                        |                   |  |
|              |                        |                   |  |
|              v                        | +---------------+ |  |
|       +------+-----+                  | | main_hook     | |  |
|       |  Backward  |                  | |               | |  |
|       |  (local)   |                  | |               | |  |
|       +------+-----+                  | |               | |  |
|              |                        | |               | |  |
|              |                        | |               | |  |
|              v                        | |               | |  |
|       +------+-----+                  | |               | |  |
|       | All-Reduce |     Sync grads   | |   All-Reduce  | |  |
|       |            | <--------------> | |   (Dummy)     | |  |
|       +------+-----+                  | |               | |  |
|              |                        | +---------------+ |  |
|              |                        +-------------------+  |
|              v                                 |             |
|     +--------+-------+                         |             |
|     | Update Weights |                         |             |
|     |                |                         |             |
|     +--------+-------+                         |             |
|              |                                 |             |
|              |                                 |             |
+--------------+                                 +-------------+
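
The "All-Reduce (Dummy)" box above corresponds to a zero-valued collective issued by a joined rank. Below is a minimal sketch of that idea, assuming the process group has already been initialized (as in the example of section 5.2.1); it is an illustration, not the actual Join internals.

import torch
import torch.distributed as dist

def shadow_allreduce(device: torch.device, process_group=None) -> torch.Tensor:
    t = torch.zeros(1, device=device)        # zero contribution, so the sum is unchanged
    dist.all_reduce(t, group=process_group)  # matches the collective issued by active ranks
    return t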

5.2 Usage

5.2.1 DistributedDataParallel

Join can be used together with DistributedDataParallel. In the following example, two workers are started, rank 0 and rank 1. Rank 0 gets 5 inputs and rank 1 gets 6, so the inputs are uneven.

Without Join, rank 1 would hang on its sixth input: since rank 0 has no corresponding input, rank 1 could only wait. With Join, this problem does not occur and training ends smoothly.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel as DDP

BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5

def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
    # Rank 1 gets one more input than rank 0
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    num_inputs = 0
    with Join([model]):
        for input in inputs:
            num_inputs += 1
            loss = model(input).sum()
            loss.backward()

    print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")

def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)

if __name__ == "__main__":
    main()

This produces the following output (the lines printed by rank 0 and rank 1 may appear in any order):

Rank 0 has exhausted all 5 of its inputs!
Rank 1 has exhausted all 6 of its inputs!

5.2.2 ZeroRedundancyOptimizer

The Join context manager works not only with a single class but also with multiple classes together, for example PyTorch's ZeroRedundancyOptimizer:

from torch.distributed.optim import ZeroRedundancyOptimizer as ZeRO
from torch.optim import Adam

def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
    optim = ZeRO(model.parameters(), Adam, lr=0.01)
    # Rank 1 gets one more input than rank 0
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    num_inputs = 0
    # Pass both `model` and `optim` into `Join()`
    with Join([model, optim]):
        for input in inputs:
            num_inputs += 1
            loss = model(input).sum()
            loss.backward()
            optim.step()

    print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")

This produces the same output as before. The notable change is that the ZeroRedundancyOptimizer instance must additionally be passed into Join().

The mechanism of ZeroRedundancyOptimizer will be analyzed in a follow-up article.

5.3 Principle

PyTorch gives some explanations in the latest documentation, https://pytorch.org/tutorials/advanced/generic_join.html, which we translate and annotate below.

For better usage, we introduce the Join class together with its supporting classes Joinable and JoinHook.

Note: this part is based on the v1.10.0 code.

5.3.1 Joinable

First, a class that wants to be compatible with the Join context manager must inherit from the abstract base class Joinable. In particular, a Joinable must implement the following:

  • join_hook(self, **kwargs) -> JoinHook

This returns a JoinHook instance for the Joinable, which determines how joined processes should shadow the per-iteration collective communications performed by the Joinable.

  • join_device(self) -> torch.device

This returns the device that the Join context manager uses to perform collective communications, e.g. torch.device("cuda:0") or torch.device("cpu").

  • join_process_group(self) -> ProcessGroup

This returns the process group that the Join context manager uses to perform collective communications.

To sum up, the JoinHook is responsible for the concrete actions, while join_device and join_process_group are responsible for the concrete collective communications.

Note that join_device and join_process_group are required attributes; they ensure that the context manager can schedule collective communications between the joined and the not-yet-joined processes. One usage is to count the number of not-yet-joined processes in each iteration with an allreduce. Another usage is to implement the mechanism required by throw_on_early_termination=True, which we will explain below.
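
As a rough illustration of the first usage above, here is a hedged sketch (not the actual Join internals) of counting the not-yet-joined ranks with an allreduce over join_device / join_process_group:

import torch
import torch.distributed as dist

def count_not_joined_ranks(join_device: torch.device, join_process_group) -> int:
    # Each not-yet-joined rank contributes 1; joined ranks shadow this with 0.
    t = torch.ones(1, device=join_device)
    dist.all_reduce(t, group=join_process_group)
    return int(t.item())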

DistributedDataParallel and ZeroRedundancyOptimizer both inherit from Joinable and implement the methods above, which is why we could use them directly in the previous examples.

class DistributedDataParallel(Module, Joinable):

class ZeroRedundancyOptimizer(Optimizer, Joinable):

DDP deals with the input data directly, so it is natural for it to inherit from Joinable. But why does ZeroRedundancyOptimizer also need to inherit from it? Because ZeroRedundancyOptimizer works together with DDP and performs collective operations internally, so it also needs to be managed by Join.

A Joinable class should make sure to call the Joinable constructor, because it initializes a JoinConfig instance, which the context manager uses internally to ensure correctness. The JoinConfig is saved in each Joinable's _join_config field.

5.3.2 JoinHook

Next, let us break down the JoinHook class. A JoinHook provides two entry points into the context manager:

  • main_hook(self) -> None

While there exists a rank that has not yet joined, each joined rank calls this hook repeatedly, once per iteration. Its purpose is to shadow the per-iteration collective communications performed by the Joinable (e.g. in the forward pass, backward pass and optimizer step), i.e. to define how a joined rank keeps performing collective communications with the not-yet-joined ranks.

  • post_hook(self, is_last_joiner: bool) -> None

Once all ranks have joined, this hook is called. It receives an additional bool argument is_last_joiner, which indicates whether this rank is one of the last to join. This argument can be useful for synchronization.

5.3.2.1 ZeroRedundancyOptimizer

Let us use the built-in main hook of ZeroRedundancyOptimizer as a concrete example: since a joined rank is still responsible for updating and synchronizing its shard of the parameters, the main hook still performs the optimizer step.

class _ZeROJoinHook(_JoinHook):
    def __init__(self, zero):
        assert isinstance(zero, ZeroRedundancyOptimizer), \
            "ZeRO join hook requires passing in a ZeroRedundancyOptimizer " \
            "instance as the state"
        self.zero = zero
        super().__init__()

    def main_hook(self):
        """
        Performs an optimizer step, which updates the joined process's shard of
        the parameters and broadcasts those parameters.
        """
        self.zero.step()

The step function is as follows:

def step(
    self,
    closure: Optional[Callable[[], float]] = None,
    **kwargs: Any,
) -> Optional[float]:
    _Join.notify_join_context(self) # notify the Join context here
    # Check if the model trainability has changed
    is_trainable_mask = self._get_is_trainable_mask()
    if is_trainable_mask != self._is_trainable_mask:
        self._build_param_buckets()
        self._is_trainable_mask = is_trainable_mask

    # Sync the exposed `param_groups` attributes to the local optimizer in
    # case they have been updated
    self._sync_param_groups(self.param_groups, self.optim.param_groups)

    # Run the optimizer step on this shard only
    if closure is not None:
        loss = self.optim.step(closure=closure, **kwargs)  # type: ignore[call-arg]
    else:
        loss = self.optim.step(**kwargs)

    # Sync all of the updated parameter shards across the ranks
    self._sync_parameters()

    # Sync any updated attributes in the local optimizer to the exposed
    # `param_groups`
    self._sync_param_groups(self.optim.param_groups, self.param_groups)

    return loss

Take another look at DistributedDataParallel:

  • main_hook still performs a series of related operations in order to fool the other ranks.
  • post_hook broadcasts the final updated model from one of the last-joining ranks, to ensure that the model is the same on all ranks.

class _DDPJoinHook(_JoinHook):
    def __init__(self, ddp, divide_by_initial_world_size):
        """
        Sets config variables for internal usage.
        """
        ddp.logger._set_uneven_input_join()
        self.ddp = ddp
        self.ddp._divide_by_initial_world_size = divide_by_initial_world_size
        super().__init__()

    def main_hook(self):
        """
        Shadows the DDP collective communication operations in the forward and
        backward passes.
        """
        ddp = self.ddp
        # Buckets are rebuilt only once during a training period
        ddp.reducer._rebuild_buckets()

        # Schedule a broadcast if we are syncing module buffers in the
        # forward pass
        ddp._check_and_sync_module_buffers()

        # Check if need to sync in the backward pass
        work = ddp._check_global_requires_backward_grad_sync(is_joined_rank=True)
        work.wait()
        should_sync_backwards = work.result()[0].item() != 0
        # Forward parameter sync is disabled in the next iteration if we
        # are skipping gradient sync this iteration, so set
        # `require_forward_param_sync` accordingly
        ddp.require_forward_param_sync = should_sync_backwards
        if not should_sync_backwards:
            return

        # Schedule one allreduce per gradient bucket to match the backward
        # pass allreduce
        ddp._match_all_reduce_for_bwd_pass()

        # Check if we need to allreduce locally unused parameters
        if ddp.find_unused_parameters:
            ddp._match_unused_params_allreduce()

        # Rebuilt parameters are pushed only once during a training period
        ddp.reducer._push_all_rebuilt_params()

    def post_hook(self, is_last_joiner: bool):
        """
        Syncs the final model to ensure that the model is the same across all
        processes.
        """
        self.ddp._sync_final_model(is_last_joiner)

_sync_final_model broadcasts the latest model here.

# When running in join model, agrees upon a common rank and broadcast model
# parameters to all other ranks.
def _sync_final_model(self, is_last_joiner):
    # Agree upon the process that will be the authoritative model copy.
    # The current rank is a candidate for being the authoritative copy if
    # is_last_joiner=True. We break ties via picking the larger rank.
    self._authoritative_rank = self._find_common_rank(
        self._distributed_rank, is_last_joiner
    )
    self._sync_params_and_buffers(authoritative_rank=self._authoritative_rank)

5.3.3 Join

Finally, let us look at how these basic classes fit into the Join class itself.

  • __init__(self, joinables: List[Joinable], enable: bool = True, throw_on_early_termination: bool = False)

As we saw in the previous examples, the constructor takes a list of the Joinables that participate in the training loop; these should be the classes that perform collective communications in each iteration.

enable is a bool that can be set to False if you know there will be no uneven inputs, in which case the context manager becomes similar to contextlib.nullcontext(). This may also disable join-related computation in the participating Joinables.

throw_on_early_termination is a bool that can be set to True so that each rank raises an exception the moment uneven inputs are detected. This is useful for cases that do not satisfy the requirements of the context manager, typically when collective communications from different classes may be interleaved arbitrarily, such as when using DistributedDataParallel with a model that contains SyncBatchNorm layers. In such cases this argument should be set to True so that the application logic can catch the exception and decide how to proceed (see the sketch after this list).

  • The core logic is in the __exit__() method: while there are not-yet-joined ranks, it loops over the Joinables and calls each one's main hook; once all ranks have joined, it calls their post hooks. Both the main hooks and the post hooks are iterated over in the order in which the Joinables were passed in.
  • The context manager requires a heartbeat from processes that have not yet joined. Therefore, each Joinable class should call Join.notify_join_context() before its per-iteration collective communications. The context manager ensures that only the first Joinable passed in actually sends the heartbeat.
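
Below is a hedged usage sketch of throw_on_early_termination=True, as referenced in the list above. It assumes model is a DDP-wrapped module and inputs is this rank's local data, as in the example of section 5.2.1; it is an illustration rather than a complete program.

from torch.distributed.algorithms.join import Join

try:
    with Join([model], throw_on_early_termination=True):
        for inp in inputs:
            loss = model(inp).sum()
            loss.backward()
except RuntimeError:
    # Some rank exhausted its inputs; every rank raises here, so the
    # application logic can decide how to proceed (e.g. end the epoch).
    pass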

5.4 Example

Let us look at a concrete example. In the following code, each rank prints (1) the number of inputs across all ranks that were seen before it joined, and (2) the total number of inputs across all ranks.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join, Joinable, JoinHook

BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5

class CounterJoinHook(JoinHook):
    r"""
    Join hook for :class:`Counter`.

    Arguments:
        counter (Counter): the :class:`Counter` object using this hook.
        sync_max_count (bool): whether to sync the max count once all ranks
            join.
    """
    def __init__(
        self,
        counter,
        sync_max_count
    ):
        self.counter = counter
        self.sync_max_count = sync_max_count

    def main_hook(self):
        r"""
        Shadows the counter's all-reduce by all-reducing a dim-1 zero tensor.
        """
        t = torch.zeros(1, device=self.counter.device)
        dist.all_reduce(t)

    def post_hook(self, is_last_joiner: bool):
        r"""
        Synchronizes the max count across all :class:`Counter` s if
        ``sync_max_count=True``.
        """
        if not self.sync_max_count:
            return
        rank = dist.get_rank(self.counter.process_group)
        common_rank = self.counter.find_common_rank(rank, is_last_joiner)
        if rank == common_rank:
            self.counter.max_count = self.counter.count.detach().clone()
        dist.broadcast(self.counter.max_count, src=common_rank)

class Counter(Joinable):
    r"""
    Example :class:`Joinable` that counts the number of training iterations
    that it participates in.
    """
    def __init__(self, device, process_group):
        super(Counter, self).__init__()
        self.device = device
        self.process_group = process_group
        self.count = torch.tensor([0], device=device).float()
        self.max_count = torch.tensor([0], device=device).float()

    def __call__(self):
        r"""
        Counts the number of inputs processed on this iteration by all ranks
        by all-reducing a dim-1 one tensor; increments its own internal count.
        """
        Join.notify_join_context(self)
        t = torch.ones(1, device=self.device).float()
        dist.all_reduce(t)
        self.count += t

    def join_hook(self, **kwargs) -> JoinHook:
        r"""
        Return a join hook that shadows the all-reduce in :meth:`__call__`.

        This join hook supports the following keyword arguments:
            sync_max_count (bool, optional): whether to synchronize the maximum
                count across all ranks once all ranks join; default is ``False``.
        """
        sync_max_count = kwargs.get("sync_max_count", False)
        return CounterJoinHook(self, sync_max_count)

    @property
    def join_device(self) -> torch.device:
        return self.device

    @property
    def join_process_group(self):
        return self.process_group

    # Determine the rank to sync from: more than one rank may be among the last to join, so pick the largest such rank.
    def find_common_rank(self, rank, to_consider):
        r"""
        Returns the max rank of the ones to consider over the process group.
        """
        common_rank = torch.tensor([rank if to_consider else -1], device=self.device)
        dist.all_reduce(common_rank, op=dist.ReduceOp.MAX, group=self.process_group)
        common_rank = common_rank.item()
        return common_rank

def worker(rank):
    assert torch.cuda.device_count() >= WORLD_SIZE
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    counter = Counter(torch.device(f"cuda:{rank}"), dist.group.WORLD)
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    with Join([counter], sync_max_count=True):
        for _ in inputs:
            counter()

    print(f"{int(counter.count.item())} inputs processed before rank {rank} joined!")
    print(f"{int(counter.max_count.item())} inputs processed across all ranks!")

def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)

if __name__ == "__main__":
    main()

Since rank 0 sees 5 inputs and rank 1 sees 6, both ranks participate in the first 5 iterations (2 x 5 = 10 inputs before rank 0 joins), and rank 1 then processes one more input on its own, giving 11 in total. This produces the following output:

10 inputs processed before rank 0 joined!
11 inputs processed across all ranks!
11 inputs processed before rank 1 joined!
11 inputs processed across all ranks!

Some key points to emphasize:

  • The Counter instance performs an allreduce in each iteration, so:
    For a joined rank, its main hook also performs a single allreduce to shadow that collective communication. Note that this allreduce uses a zero tensor, so it has no effect on the overall result.
    The other, not-yet-joined ranks still see this as a correct collective operation.
    This handles the uneven inputs.
  • The Counter class calls Join.notify_join_context() at the beginning of its __call__() method, because that is where its per-iteration collective operation (the allreduce) happens and the context manager must be notified there. Ranks that have already joined no longer reach this call.
  • The is_last_joiner argument is used to determine the broadcast source in the post-hooks.
  • The sync_max_count keyword argument is passed to the context manager, which forwards it to Counter's join hook.
  • In the post-hook, self.counter.max_count is broadcast.

0xFF References

Pytorch distributed series 3 – what does torch.utils.data.distributed.distributedsampler do during distributed training?

Pytorch distributed series 1 — find out the environment variables related to torch.distributed.launch

How does pytorch distributed series 2 – distributed data parallel synchronize?

Summary of personal practice of pytorch (distributed) data parallel — dataparallel / distributed dataparallel

PyTorch's nn.DataParallel

https://discuss.pytorch.org/t/dataparallel-imbalanced-memory-usage/22551/20

https://pytorch.org/docs/stable/distributed.html

Pytorch source code interpretation of distributed training to understand?

Practical tutorial | PyTorch autograd C++ layer implementation

Pytorch automatic differentiation (I)

How does pytorch accelerate data parallel training? Uncover the secrets of distributed Secrets

Pytorch distributed training (II init_process_group)

https://pytorch.org/tutorials/intermediate/ddp_tutorial.html

https://pytorch.org/docs/master/notes/ddp.html

https://pytorch.org/tutorials/intermediate/dist_tuto.html

Interpretation of PyTorch source code DP & DDP: model parallel and distributed training analysis

Parameters and buffers in PyTorch models

[pytorch Developer Day 2020] pytorch distributed data parallelism (DDP)

[Chinese subtitle] deeply understand the hook mechanism in pytorch

[Chinese subtitle] in depth interpretation of pytoch autograd

DISTRIBUTED TRAINING WITH UNEVEN INPUTS USING THE JOIN CONTEXT MANAGER

On ZeroRedundancyOptimizer and Join in torch 1.10