PyTorch uses a dynamic (define-by-run) computation model. Training consists of a forward pass and a backward pass: the order of the forward pass is determined by the user's computation code, and the computation graph consumed by the backward pass is recorded as the forward pass runs. So how, concretely, is the backward pass implemented in PyTorch?
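Before diving into the C++ internals, here is a minimal sketch using only the public Python API, showing the two phases: the forward line records the graph, and backward() traverses it.
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = (x * 2).mean()      # forward pass: records MulBackward0 -> MeanBackward0
print(y.grad_fn)        # <MeanBackward0 object at ...>
y.backward()            # backward pass: walks the recorded graph
print(x.grad)           # tensor([0.6667, 0.6667, 0.6667]), i.e. 2 / 3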
Data structures
The Node class represents a node in the autograd graph; Edge is an edge in the autograd graph and connects nodes.
struct TORCH_API Node : std::enable_shared_from_this<Node> {
/// Performs the `Node`'s actual operation.
virtual variable_list apply(variable_list&& inputs) = 0;
edge_list next_edges_;
};
struct Edge {
Edge() noexcept : function(nullptr), input_nr(0) {}
Edge(std::shared_ptr<Node> function_, uint32_t input_nr_) noexcept
: function(std::move(function_)), input_nr(input_nr_) {}
/// Convenience method to test if an edge is valid.
bool is_valid() const noexcept {
return function != nullptr;
}
// Required for use in associative containers.
bool operator==(const Edge& other) const noexcept {
return this->function == other.function && this->input_nr == other.input_nr;
}
bool operator!=(const Edge& other) const noexcept {
return !(*this == other);
}
/// The function this `Edge` points to.
std::shared_ptr<Node> function;
/// The identifier of a particular input to the function.
uint32_t input_nr;
};
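From Python, grad_fn and grad_fn.next_functions are the visible counterparts of Node and edge_list: each entry of next_functions is a (function, input_nr) pair, i.e. an Edge. A small sketch using only the public API:
import torch

a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
c = (a * b).sum()
print(c.grad_fn)                  # SumBackward0, a Node
print(c.grad_fn.next_functions)   # ((<MulBackward0 ...>, 0),) -> Edge{function, input_nr}
print(c.grad_fn.next_functions[0][0].next_functions)
# ((<AccumulateGrad ...>, 0), (<AccumulateGrad ...>, 0)) -> edges to the leaf tensors a and b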
Every operator has a forward function and a backward gradient function. In PyTorch the backward gradient function of each operator is a struct that derives (via TraceableFunction) from Node; for example, the gradient function of the mean operator is:
struct TORCH_API MeanBackward0 : public TraceableFunction {
using TraceableFunction::TraceableFunction;
variable_list apply(variable_list&& grads) override;
std::string name() const override { return "MeanBackward0"; }
void release_variables() override {
}
std::vector<int64_t> self_sizes;
int64_t self_numel = 0;
at::ScalarType self_scalar_type;
};
Its apply method is the concrete implementation of the backward computation:
variable_list MeanBackward0::apply(variable_list&& grads) {
IndexRangeGenerator gen;
auto self_ix = gen.range(1);
variable_list grad_inputs(gen.size());
const auto& grad = grads[0];
bool any_grad_defined = any_variable_defined(grads);
if (should_compute_output({ self_ix })) {
auto grad_result = any_grad_defined ? (grad.expand(self_sizes).to(self_scalar_type) / self_numel) : Tensor();
copy_range(grad_inputs, self_ix, grad_result);
}
return grad_inputs;
}
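The gradient of mean is simply the upstream gradient broadcast back to the input shape and divided by the element count, which is exactly what apply() above computes. A quick sanity check from the Python side (public API only):
import torch

x = torch.randn(2, 3, requires_grad=True)
y = x.mean()
grad_out = torch.tensor(5.0)
y.backward(grad_out)

# what MeanBackward0::apply computes: grad.expand(self_sizes) / self_numel
manual = grad_out.expand_as(x) / x.numel()
print(torch.allclose(x.grad, manual))   # True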
Building the autograd graph
The autograd graph is built during the forward pass.
The Tensor class has a member grad_fn, the gradient function, i.e. the backward function of the operation that produced the tensor:
const std::shared_ptr<torch::autograd::Node>& grad_fn() const;
The gradient function is created and attached when the tensor is produced. Take the forward function mean below as an example.
First, inside the if (_any_requires_grad) block, the MeanBackward0 gradient-function object is created, and the values needed by the backward formula (self_sizes, self_numel, self_scalar_type) are saved on it.
Next, the lambda that calls at::redispatch::mean performs the actual forward computation and produces the output tensor result.
Finally, set_history(flatten_tensor_args(result), grad_fn) attaches the gradient-function object to result.
In addition, grad_fn->set_next_edges(collect_next_edges(self)) establishes the relationship between graph nodes: it creates an edge between self and result, making self the next node after result. The forward pass computes in the order self -> result, whereas the backward pass first computes the gradient of result and then the gradient of self, i.e. in the order result -> self.
at::Tensor mean(c10::DispatchKeySet ks, const at::Tensor & self, c10::optional<at::ScalarType> dtype) {
auto& self_ = unpack(self, "self", 0);
auto _any_requires_grad = compute_requires_grad( self );
(void)_any_requires_grad;
auto _any_has_forward_grad_result = isFwGradDefined(self);
(void)_any_has_forward_grad_result;
std::shared_ptr<MeanBackward0> grad_fn;
if (_any_requires_grad) {
grad_fn = std::shared_ptr<MeanBackward0>(new MeanBackward0(), deleteNode);
grad_fn->set_next_edges(collect_next_edges( self ));
grad_fn->self_sizes = self.sizes().vec();
grad_fn->self_numel = self.numel();
grad_fn->self_scalar_type = self.scalar_type();
}
#ifndef NDEBUG
c10::optional<Storage> self__storage_saved =
self_.has_storage() ? c10::optional<Storage>(self_.storage()) : c10::nullopt;
c10::intrusive_ptr<TensorImpl> self__impl_saved;
if (self_.defined()) self__impl_saved = self_.getIntrusivePtr();
#endif
auto _tmp = ([&]() {
at::AutoDispatchBelowADInplaceOrView guard;
return at::redispatch::mean(ks & c10::after_autograd_keyset, self_, dtype);
})();
auto result = std::move(_tmp);
#ifndef NDEBUG
if (self__storage_saved.has_value())
AT_ASSERT(self__storage_saved.value().is_alias_of(self_.storage()));
if (self__impl_saved) AT_ASSERT(self__impl_saved == self_.getIntrusivePtr());
if (result.has_storage()) AT_ASSERT(result.storage().use_count() == 1, "function: mean");
AT_ASSERT(result.use_count() <= 1, "function: mean");
#endif
if (grad_fn) {
set_history(flatten_tensor_args( result ), grad_fn);
}
if (_any_has_forward_grad_result) {
auto self_t_raw = toNonOptFwGrad(self);
auto self_t = self_t_raw.defined() ? self_t_raw : at::zeros_like(toNonOptTensor(self));
auto result_new_fw_grad = at::mean(self_t, dtype);
if (result_new_fw_grad.defined()) {
// The hardcoded 0 here will need to be updated once we support multiple levels.
result._set_fw_grad(result_new_fw_grad, /* level */ 0, /* is_inplace_op */ false);
}
}
return result;
}
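The effect of the code above is easy to observe from Python: when the input requires grad, set_history attaches a MeanBackward0 node to the result and collect_next_edges(self) yields an edge to the AccumulateGrad node of the leaf input; otherwise the grad_fn branch is skipped entirely. A short check:
import torch

self_ = torch.randn(4, requires_grad=True)
result = self_.mean()
print(result.grad_fn)                 # <MeanBackward0 object at ...>
print(result.grad_fn.next_functions)  # ((<AccumulateGrad object at ...>, 0),)

no_grad_input = torch.randn(4)        # requires_grad=False
print(no_grad_input.mean().grad_fn)   # None: _any_requires_grad is false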
Backward pass
Once the forward pass finishes, the relationships between the gradient functions are already in place; calling loss.backward() runs the backward pass.
The code executes in the following order.
- Backward execution starts from the top-level (root) nodes, in Engine::execute:
auto Engine::execute(const edge_list& roots,
const variable_list& inputs,
bool keep_graph,
bool create_graph,
bool accumulate_grad,
const edge_list& outputs) -> variable_list {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
// ... (code elided)
auto graph_task = std::make_shared<GraphTask>(
/* keep_graph */ keep_graph,
/* create_graph */ create_graph,
/* depth */ not_reentrant_backward_call ? 0 : total_depth + 1,
/* cpu_ready_queue */ local_ready_queue);
// If we receive a single root, skip creating extra root node
bool skip_dummy_node = roots.size() == 1;
auto graph_root = skip_dummy_node ?
roots.at(0).function :
std::make_shared<GraphRoot>(roots, inputs);
auto min_topo_nr = compute_min_topological_nr(outputs);
// Now compute the dependencies for all executable functions
compute_dependencies(graph_root.get(), *graph_task, min_topo_nr);
if (!outputs.empty()) {
graph_task->init_to_execute(*graph_root, outputs, accumulate_grad, min_topo_nr);
}
// Queue the root
if (skip_dummy_node) {
InputBuffer input_buffer(roots.at(0).function->num_inputs());
auto input = inputs.at(0);
const auto input_stream = InputMetadata(input).stream();
const auto opt_next_stream = roots.at(0).function->stream(c10::DeviceType::CUDA);
input_buffer.add(roots.at(0).input_nr,
std::move(input),
input_stream,
opt_next_stream);
execute_with_graph_task(graph_task, graph_root, std::move(input_buffer));
} else {
execute_with_graph_task(graph_task, graph_root, InputBuffer(variable_list()));
}
// Avoid a refcount bump for the Future, since we check for refcount in
// DistEngine (see TORCH_INTERNAL_ASSERT(futureGrads.use_count() == 1)
// in dist_engine.cpp).
auto& fut = graph_task->future_result_;
fut->wait();
return fut->value().toTensorVector();
}
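The keep_graph argument of Engine::execute is what retain_graph controls from the Python side; with keep_graph=false the engine releases each Node's saved tensors after the run, so the same graph cannot be traversed again. A minimal illustration:
import torch

x = torch.randn(3, requires_grad=True)
loss = (x ** 2).sum()
loss.backward(retain_graph=True)   # keep_graph=true: saved tensors are kept
loss.backward()                    # still works, and frees the graph this time
try:
    loss.backward()                # third attempt: graph already freed
except RuntimeError as e:
    print(e)                       # "Trying to backward through the graph a second time ..."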
- The execution of one gradient-function Node is wrapped in a NodeTask and pushed onto a ready queue: the root node is pushed first, then popped and processed, then its successor nodes are pushed, popped in turn, and so on. This is set up in Engine::execute_with_graph_task:
c10::intrusive_ptr<at::ivalue::Future> Engine::execute_with_graph_task(
const std::shared_ptr<GraphTask>& graph_task,
std::shared_ptr<Node> graph_root,
InputBuffer&& input_buffer) {
initialize_device_threads_pool();
// Lock mutex for GraphTask.
std::unique_lock<std::mutex> lock(graph_task->mutex_);
auto queue = ready_queue(graph_task->cpu_ready_queue_, input_buffer.device());
// worker_device == NO_DEVICE it's a CPU thread and it's trying to drive the
// autograd engine with corresponding GraphTask, and its NOT a re-entrant call
if (worker_device == NO_DEVICE) {
// We set the worker_device to CPU_DEVICE only if worker_device was previously
// NO_DEVICE. Setting it to CPU afterwards allow us to detect whether this is
// a re-entrant call or not.
set_device(CPU_DEVICE);
// set the graph_task owner to the current device
graph_task->owner_ = worker_device;
// Now that all the non-thread safe fields of the graph_task have been populated,
// we can enqueue it.
queue->push(NodeTask(graph_task, std::move(graph_root), std::move(input_buffer)));
// The owning thread start to drive the engine execution for any CPU task that
// was just pushed or will be added later from other worker threads
lock.unlock();
thread_main(graph_task);
TORCH_INTERNAL_ASSERT(graph_task->future_result_->completed());
// reset the worker_device after the completion of the graph_task, this is so
// that the initial state of the engine remains the same across every backward()
// or grad() call, we don't need to reset local_ready_queue as we could possibly
// reuse it for new backward calls.
worker_device = NO_DEVICE;
} else {
// If worker_device is any devices (i.e. CPU, CUDA): this is a re-entrant
// backward call from that device.
graph_task->owner_ = worker_device;
// Now that all the non-thread safe fields of the graph_task have been populated,
// we can enqueue it.
queue->push(NodeTask(graph_task, std::move(graph_root), std::move(input_buffer)));
if (current_depth >= max_recursion_depth_) {
// See Note [Reentrant backwards]
// If reached the max depth, switch to a different thread
add_thread_pool_task(graph_task);
} else {
// Total depth needs to be updated only in this codepath, since it is
// not used in the block above (when we call add_thread_pool_task).
// In the codepath above, GraphTask.reentrant_depth_ is used to
// bootstrap total_depth in the other thread.
++total_depth;
// Get back to work while we wait for our new graph_task to
// complete!
++current_depth;
lock.unlock();
thread_main(graph_task);
--current_depth;
--total_depth;
// The graph task should have completed and the associated future should
// be marked completed as well since 'thread_main' above is a call
// blocking an autograd engine thread.
TORCH_INTERNAL_ASSERT(graph_task->future_result_->completed());
}
}
// graph_task_exec_post_processing is done when the Future is marked as
// completed in mark_as_completed_and_run_post_processing.
return graph_task->future_result_;
}
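The else branch above handles re-entrant backward: a backward that is started while another backward is already running on this thread. One way to hit that path from Python is a custom autograd Function whose backward itself calls torch.autograd.grad; a small sketch (the Function and its math are only illustrative):
import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        # Re-enter the autograd engine while the outer backward is still running.
        with torch.enable_grad():
            y = x.detach().requires_grad_()
            (inner_grad,) = torch.autograd.grad((y * y).sum(), y)
        return grad_out * inner_grad      # inner_grad == 2x

x = torch.randn(3, requires_grad=True)
Square.apply(x).sum().backward()
print(torch.allclose(x.grad, 2 * x))      # True: d(x^2)/dx = 2x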
- The while loop in thread_main pops one NodeTask from the ready queue on each iteration and processes it:
auto Engine::thread_main(const std::shared_ptr<GraphTask>& graph_task) -> void {
// When graph_task is nullptr, this is a long running thread that processes
// tasks (ex: device threads). When graph_task is non-null (ex: reentrant
// backwards, user thread), this function is expected to exit once that
// graph_task complete.
// local_ready_queue should already been initialized when we get into thread_main
TORCH_INTERNAL_ASSERT(local_ready_queue != nullptr);
while (graph_task == nullptr || !graph_task->future_result_->completed()) {
// local_graph_task represents the graph_task we retrieve from the queue.
// The outer graph_task represents the overall graph_task we need to execute
// for reentrant execution.
std::shared_ptr<GraphTask> local_graph_task;
{
// Scope this block of execution since NodeTask is not needed after this
// block and can be deallocated (release any references to grad tensors
// as part of inputs_).
NodeTask task = local_ready_queue->pop();
// This will only work if the worker is running a non backward task
// TODO Needs to be fixed this to work in all cases
if (task.isShutdownTask_) {
C10_LOG_API_USAGE_ONCE("torch.autograd.thread_shutdown");
break;
}
if (!(local_graph_task = task.base_.lock())) {
// GraphTask for function is no longer valid, skipping further
// execution.
continue;
}
if (task.fn_ && !local_graph_task->has_error_.load()) {
// Set the ThreadLocalState before calling the function.
// NB: The ThreadLocalStateGuard doesn't set the grad_mode because GraphTask
// always saves ThreadLocalState without grad_mode.
at::ThreadLocalStateGuard tls_guard(local_graph_task->thread_locals_);
try {
// The guard sets the thread_local current_graph_task on construction
// and restores it on exit. The current_graph_task variable helps
// queue_callback() to find the target GraphTask to append final
// callbacks.
GraphTaskGuard guard(local_graph_task);
NodeGuard ndguard(task.fn_);
{
RECORD_FUNCTION(
c10::str(
"autograd::engine::evaluate_function: ",
task.fn_.get()->name()),
std::vector<c10::IValue>());
evaluate_function(
local_graph_task,
task.fn_.get(),
task.inputs_,
local_graph_task->cpu_ready_queue_);
}
} catch (std::exception& e) {
thread_on_exception(local_graph_task, task.fn_, e);
}
}
}
// Decrement the outstanding tasks.
--local_graph_task->outstanding_tasks_;
// Check if we've completed execution.
if (local_graph_task->completed()) {
local_graph_task->mark_as_completed_and_run_post_processing();
auto base_owner = local_graph_task->owner_;
// The current worker thread finish the graph_task, but the owning thread
// of the graph_task might be sleeping on pop() if it does not have work.
// So we need to send a dummy function task to the owning thread just to
// ensure that it's not sleeping, so that we can exit the thread_main.
// If it has work, it might see that graph_task->outstanding_tasks_ == 0
// before it gets to the task, but it's a no-op anyway.
//
// NB: This is not necessary if the current thread is the owning thread.
if (worker_device != base_owner) {
// Synchronize outstanding_tasks_ with queue mutex
std::atomic_thread_fence(std::memory_order_release);
ready_queue_by_index(local_graph_task->cpu_ready_queue_, base_owner)
->push(NodeTask(local_graph_task, nullptr, InputBuffer(0)));
}
}
}
}
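thread_main is the loop run by every thread that drives a backward, whether it is a device worker thread or the Python thread that called backward() (the worker_device == NO_DEVICE case in execute_with_graph_task). One practical consequence is that several Python threads can each drive their own GraphTask concurrently, each in its own thread_main loop; a small sketch:
import threading
import torch

def work():
    # each thread owns an independent graph and drives its own backward
    x = torch.randn(100, requires_grad=True)
    (x * x).sum().backward()
    assert torch.allclose(x.grad, 2 * x)

threads = [threading.Thread(target=work) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print("all backward calls finished")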
- evaluate_function does two main things: it calls call_function to run the gradient function, and it pushes successor nodes onto the ready queue once they become ready (a toy sketch of this scheduling follows the code below):
void Engine::evaluate_function(
std::shared_ptr<GraphTask>& graph_task,
Node* func,
InputBuffer& inputs,
const std::shared_ptr<ReadyQueue>& cpu_ready_queue) {
// The InputBuffer::adds that supplied incoming grads took pains to
// ensure they're safe to consume in the context of the present
// func's stream (if applicable). So we guard onto that stream
// before working with the grads in any capacity.
const auto opt_parent_stream = (*func).stream(c10::DeviceType::CUDA);
c10::OptionalStreamGuard parent_stream_guard{opt_parent_stream};
// If exec_info_ is not empty, we have to instrument the execution
auto& exec_info_ = graph_task->exec_info_;
if (!exec_info_.empty()) {
auto& fn_info = exec_info_.at(func);
if (auto* capture_vec = fn_info.captures_.get()) {
// Lock mutex for writing to graph_task->captured_vars_.
std::lock_guard<std::mutex> lock(graph_task->mutex_);
for (const auto& capture : *capture_vec) {
auto& captured_grad = graph_task->captured_vars_[capture.output_idx_];
captured_grad = inputs[capture.input_idx_];
for (auto& hook : capture.hooks_) {
captured_grad = (*hook)(captured_grad);
}
if (opt_parent_stream) {
// No need to take graph_task->mutex_ here, we already hold it
graph_task->leaf_streams.emplace(*opt_parent_stream);
}
}
}
if (!fn_info.needed_) {
// Skip execution if we don't need to execute the function.
return;
}
}
auto outputs = call_function(graph_task, func, inputs);
auto& fn = *func;
if (!graph_task->keep_graph_) {
fn.release_variables();
}
int num_outputs = outputs.size();
if (num_outputs == 0) { // Note: doesn't acquire the mutex
// Records leaf stream (if applicable)
// See Note [Streaming backwards]
if (opt_parent_stream) {
std::lock_guard<std::mutex> lock(graph_task->mutex_);
graph_task->leaf_streams.emplace(*opt_parent_stream);
}
return;
}
if (AnomalyMode::is_enabled()) {
AutoGradMode grad_mode(false);
for (const auto i : c10::irange(num_outputs)) {
auto& output = outputs[i];
at::OptionalDeviceGuard guard(device_of(output));
if (output.defined() && isnan(output).any().item<uint8_t>()) {
std::stringstream ss;
ss << "Function '" << fn.name() << "' returned nan values in its " << i << "th output.";
throw std::runtime_error(ss.str());
}
}
}
// Lock mutex for the accesses to GraphTask dependencies_, not_ready_ and cpu_ready_queue_ below
std::lock_guard<std::mutex> lock(graph_task->mutex_);
for (const auto i : c10::irange(num_outputs)) {
auto& output = outputs[i];
const auto& next = fn.next_edge(i);
if (!next.is_valid()) continue;
// Check if the next function is ready to be computed
bool is_ready = false;
auto& dependencies = graph_task->dependencies_;
auto it = dependencies.find(next.function.get());
if (it == dependencies.end()) {
auto name = next.function->name();
throw std::runtime_error(std::string("dependency not found for ") + name);
} else if (--it->second == 0) {
dependencies.erase(it);
is_ready = true;
}
auto& not_ready = graph_task->not_ready_;
auto not_ready_it = not_ready.find(next.function.get());
if (not_ready_it == not_ready.end()) {
// Skip functions that aren't supposed to be executed
if (!exec_info_.empty()) {
auto it = exec_info_.find(next.function.get());
if (it == exec_info_.end() || !it->second.should_execute()) {
continue;
}
}
// No buffers have been allocated for the function
InputBuffer input_buffer(next.function->num_inputs());
// Accumulates into buffer
const auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA);
input_buffer.add(next.input_nr,
std::move(output),
opt_parent_stream,
opt_next_stream);
if (is_ready) {
auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
queue->push(
NodeTask(graph_task, next.function, std::move(input_buffer)));
} else {
not_ready.emplace(next.function.get(), std::move(input_buffer));
}
} else {
// The function already has a buffer
auto &input_buffer = not_ready_it->second;
// Accumulates into buffer
const auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA);
input_buffer.add(next.input_nr,
std::move(output),
opt_parent_stream,
opt_next_stream);
if (is_ready) {
auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
queue->push(
NodeTask(graph_task, next.function, std::move(input_buffer)));
not_ready.erase(not_ready_it);
}
}
}
}
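To make the scheduling logic concrete, here is a toy model in plain Python (not PyTorch API; the graph and names are made up): each node's dependency count is the number of edges feeding into it, and a node is pushed onto the ready queue only when that count drops to zero, mirroring compute_dependencies plus the --it->second == 0 check above.
from collections import deque

# hypothetical graph: root -> a, root -> b, a -> c, b -> c
edges = {"root": ["a", "b"], "a": ["c"], "b": ["c"], "c": []}

# compute_dependencies: count the producers of every node
dependencies = {}
for nexts in edges.values():
    for dst in nexts:
        dependencies[dst] = dependencies.get(dst, 0) + 1

ready = deque(["root"])            # the GraphRoot is ready immediately
while ready:                       # thread_main's pop loop, single-threaded here
    node = ready.popleft()
    print("evaluate", node)        # call_function(node) would run here
    for nxt in edges[node]:        # fan-out over next_edges, as above
        dependencies[nxt] -= 1
        if dependencies[nxt] == 0: # all inbound gradients have arrived
            ready.append(nxt)
# prints root, a, b, c: c runs only after both a and b have produced its inputs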
- call_function takes the gradient function and its inputs and computes the gradients as outputs; through the edges between nodes, these outputs become the inputs of the next gradient functions:
static variable_list call_function(
std::shared_ptr<GraphTask>& graph_task,
Node* func,
InputBuffer& inputBuffer) {
CheckpointValidGuard cpvguard(graph_task);
auto& fn = *func;
auto inputs =
call_pre_hooks(fn, InputBuffer::variables(std::move(inputBuffer)));
if (!graph_task->keep_graph_) {
fn.will_release_variables();
}
const auto has_post_hooks = !fn.post_hooks().empty();
variable_list outputs;
if (has_post_hooks) {
// In functions/accumulate_grad.cpp, there is some logic to check the
// conditions under which the incoming gradient can be stolen directly
// (which elides a deep copy) instead of cloned. One of these conditions
// is that the incoming gradient's refcount must be 1 (nothing else is
// referencing the same data). Stashing inputs_copy here bumps the
// refcount, so if post hooks are employed, it's actually still ok for
// accumulate_grad.cpp to steal the gradient if the refcount is 2.
//
// "new_grad.use_count() <= 1 + !post_hooks().empty()" in
// accumulate_grad.cpp accounts for this, but also creates a silent
// dependency between engine.cpp (ie, this particular engine
// implementation) and accumulate_grad.cpp.
//
// If you change the logic here, make sure it's compatible with
// accumulate_grad.cpp.
auto inputs_copy = inputs;
outputs = fn(std::move(inputs_copy));
} else {
outputs = fn(std::move(inputs));
}
validate_outputs(fn.next_edges(), outputs, [&](const std::string& msg) {
std::ostringstream ss;
ss << "Function " << fn.name() << " returned an " << msg;
return ss.str();
});
if(has_post_hooks){
// NOLINTNEXTLINE(bugprone-use-after-move)
return call_post_hooks(fn, std::move(outputs), inputs);
}
return outputs;
}
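The pre- and post-hook calls in call_function are what tensor hooks plug into: roughly speaking, a hook registered with Tensor.register_hook on a non-leaf tensor is invoked through this pre-hook path, so it sees (and may replace) the incoming gradient before apply() runs. A small example:
import torch

x = torch.randn(3, requires_grad=True)
y = x * 2

def double_grad(grad):
    print("hook saw:", grad)       # the gradient flowing into y's grad_fn
    return grad * 2                # the returned tensor replaces it

y.register_hook(double_grad)
y.sum().backward()
print(x.grad)                      # 2 (from y = 2x) * 2 (hook) = all fours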