Overhead
PyTorch 执行 eager 操作时,例如,torch.add(a, b)
,调度器(c10::Dispatcher
)会根据分派键(DispatchKey
) 来查找并执行 add op 的 op kernel (理解PyTorch分发机制的内部工作原理)。因此,算子注册过程就是在调度器中定义 op,并将 kernel function 注册到 op 的指定分派键条目中。
Torch Library
torch::Library
是算子注册用的 helper,通过它注册的算子有着相同的命名空间、dispatch key等。
TORCH_LIBRARY(myops, m) {
m.def("myadd(Tensor self, Tensor other) -> Tensor");
m.def("mysub(Tensor self, Tensor other) -> Tensor", mysub_func);
m.impl("myadd", myadd_func);
}
m
就是命名空间为 myops
的 library,它通过 m.def
定义了 myadd 和 mysub 这两个 op 的静态信息 schema。mysub 在定义的同时也将 mysub_func
函数注册到 op,而 myadd 的 op kernel 则是通过 m.impl
单独注册的。由于 TORCH_LIBRARY 宏没有指定 dispatch key,因此,这两个 op kernel 都是 CatchAll
函数。
如果要将 kernel function 注册到指定的 dispatch key,需要用到 TORCH_LIBRARY_IMPL 宏:
TORCH_LIBRARY_IMPL(myops, CUDA, m) {
m.impl("myadd", myadd_cuda);
m.impl("mysub", mysub_cuda);
}
所有通过 m
注册的 kernel function 都会注册到 op 的 CUDA
key 条目中,它执行的优先级会比 CatchAll 更高。
OperatorDef
OperatorDef
用于描述调度器中 op 的静态信息,它会提供 registerSchema()
、registerKernel()
方法给 m.def() 和 m.impl() 分别用于注册 op 和 kernel。
Kernel list
通过 m.impl() 注册的 kernel function 会插入到指定 dispatch key 的 kernel list(kernels_
)的头部,而调度器则会从列表中的首元素中获取 kernel。也就是说,PyTorch 允许为 op 的同一个 dispatch key 注册多个 kernel,而新 kernel 会覆盖旧 kernel。
class TORCH_API OperatorEntry final {
...
ska::flat_hash_map<DispatchKey, std::list<AnnotatedKernel>> kernels_;
};
const AnnotatedKernel* OperatorEntry::getKernelForDispatchKey(DispatchKey dispatch_key) const{
auto kern_it = kernels_.find(dispatch_key);
if (kern_it != kernels_.end()) {
TORCH_INTERNAL_ASSERT(!kern_it->second.empty());
TORCH_INTERNAL_ASSERT(kern_it->second.front().kernel.isValid());
return &kern_it->second.front();
}
return nullptr;
}