您现在的位置是:首页 >技术教程 >【TVM系列六】PackedFunc原理网站首页技术教程

【TVM系列六】PackedFunc原理

生活需要深度 2025-12-08 00:01:02
简介【TVM系列六】PackedFunc原理

一、前言

在TVM中,PackedFunc贯穿了整个Stack,是Python与C++进行互相调用的桥梁,深入理解PackedFunc的数据结构及相应的调用流程对理解整个TVM的代码很有帮助。

二、预备知识

1

ctypes

ctypes 是 Python 的外部函数库。它提供了与 C 兼容的数据类型,并允许调用 DLL 或共享库中的函数。可使用该模块以纯 Python 形式对这些库进行封装。

  • ctypes.byref

有时候 C 函数接口可能由于要往某个地址写入值,或者数据太大不适合作为值传递,从而希望接收一个 指针 作为数据参数类型。这和 传递参数引用 类似。

ctypes 暴露了 byref() 函数用于通过引用传递参数。

  • ctypes.c_void_p

代表 C 中的 void * 类型。该值被表示为整数形式。当给 c_void_p 赋值时,将改变它所指向的内存地址,而不是它所指向的内存区域的内容。 

2

C++ std::function

std::function是一个函数包装器模板,通过std::function对C++中各种可调用实体(普通函数、Lambda表达式、函数指针、以及其它函数对象等)的封装,形成一个新的可调用的std::function对象。它的原型说明为:

T: 通用类型,但实际通用类型模板并没有被定义,只有当T的类型为形如Ret(Args...)的函数类型才能工作。Ret: 调用函数返回值的类型。Args: 函数参数类型。

3

C++关键字decltype

decltype与auto关键字一样,用于进行编译时类型推导,不过它与auto还是有一些区别的。decltype的类型推导并不是像auto一样是从变量声明的初始化表达式获得变量的类型,而是总是以一个普通表达式作为参数,返回该表达式的类型,而且decltype并不会对表达式进行求值。它的使用方法主要有

  • 推导出表达式类型

int i = 4;decltype(i) a; //推导结果为int, a的类型为int
  • 与using/typedef合用,用于定义类型。

using size_t = decltype(sizeof(0));//sizeof(a)的返回值为size_t类型using ptrdiff_t = decltype((int*)0 - (int*)0);using nullptr_t = decltype(nullptr);vector<int >vec;typedef decltype(vec.begin()) vectype;for (vectype i = vec.begin; i != vec.end(); i++){...}

这样和auto一样,也提高了代码的可读性。

  • 重用匿名类型

在C++中,我们有时候会遇上一些匿名类型,如:

struct {  int d;  doubel b;} anon_s;

而借助decltype,我们可以重新使用这个匿名的结构体:

decltype(anon_s) as; //定义了一个上面匿名的结构体
  • 泛型编程中结合auto,用于追踪函数的返回值类型

这也是decltype最大的用途了。

template <typename _Tx, typename _Ty>auto multiply(_Tx x, _Ty y)->decltype(_Tx*_Ty){return x*y;}

三、PackedFunc原理

1

数据结构

  • TVMValue

联合体,在Python与C++互调时的数据类型

typedef union {  int64_t v_int64;  double v_float64;  void* v_handle;  const char* v_str;  DLDataType v_type;  // The data type the tensor can hold  DLDevice v_device;  // A Device for Tensor and operator} TVMValue;
  • TVMPODValue_

内部基类用于处理POD类型的转换,主要重载了强制类型转换运行符,在c++中,类型的名字,包括类的名字本身也是一种运算符,即类型强制转换运算符。

class TVMPODValue_ {public:  operator double() const {...}  operator int64_t() const {...}  operator uint64_t() const {...}  operator int() const {...}  operator bool() const {...}  operator void*() const {...}  operator DLTensor*() const {...}  operator NDArray() const {...}  operator Module() const {...}  operator Device() const {...} // 以上为强制类型转换运行符重载
  int type_code() const { return type_code_; }  template <typename T>  T* ptr() const {return static_cast<T*>(value_.v_handle);}  ...protected:  ...  TVMPODValue_() : type_code_(kTVMNullptr) {}  TVMPODValue_(TVMValue value, int type_code) : value_(value), type_code_(type_code) {} // 构造函数  TVMValue value_;  int type_code_;};
  • TVMArgValue

继承TVMPODValue_,扩充了更多的类型转换的重载,其中包括了重要的PackedFunc()与TypedPackedFunc()

class TVMArgValue : public TVMPODValue_ { public:  TVMArgValue() {}  TVMArgValue(TVMValue value, int type_code) : TVMPODValue_(value, type_code) {} // 构造函数  using TVMPODValue_::operator double;  using TVMPODValue_::operator int64_t;  using TVMPODValue_::operator uint64_t;  using TVMPODValue_::operator int;  using TVMPODValue_::operator bool;  using TVMPODValue_::operator void*;  using TVMPODValue_::operator DLTensor*;  using TVMPODValue_::operator NDArray;  using TVMPODValue_::operator Device;  using TVMPODValue_::operator Module;  using TVMPODValue_::AsObjectRef;  using TVMPODValue_::IsObjectRef; // 复用父类的类型转换函数  // conversion operator.  operator std::string() const {...}  operator PackedFunc() const {    if (type_code_ == kTVMNullptr) return PackedFunc();    TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle);    return *ptr<PackedFunc>(); // 相当于static_cast<PackedFunc*>(value_.v_handle),将void *v_handle的指针强转为PackedFunc*  }  template <typename FType>  operator TypedPackedFunc<FType>() const {    // TypedPackedFunc类中也重载了PackedFunc(),当调用operator PackedFunc()时会    return TypedPackedFunc<FType>(operator PackedFunc());  }  const TVMValue& value() const { return value_; }
  template <typename T, typename = typename std::enable_if<std::is_class<T>::value>::type>  inline operator T() const;  inline operator DLDataType() const;  inline operator DataType() const;};
  • TVMRetValue

这个类也是继承自TVMPODValue_类,主要作用是作为存放调用PackedFunc返回值的容器,它和TVMArgValue的区别是,它在析构时会释放源数据。它主要实现的函数为:构造和析构函数;对强制类型转换运算符重载的扩展;对赋值运算符的重载;辅助函数,包括释放资源的Clear函数。

class TVMRetValue : public TVMPODValue_ {public:  // 构造与析构函数  TVMRetValue() {}  ~TVMRetValue() { this->Clear(); }
  // 从父类继承的强制类型转换运算符重载  using TVMPODValue_::operator double;  using TVMPODValue_::operator int64_t;  using TVMPODValue_::operator uint64_t;  ...  // 对强制类型转换运算符重载的扩展  operator PackedFunc() const {...}  template <typename FType>  operator TypedPackedFunc<FType>() const {...}  ...  // 对赋值运算符的重载  TVMRetValue& operator=(TVMRetValue&& other) {...}  TVMRetValue& operator=(double value) {...}  TVMRetValue& operator=(std::nullptr_t value) {...}  ... private:  ...  // 根据type_code_的类型释放相应的源数据  void Clear() {...}}
  • TVMArgs

传给PackedFunc的参数,定义了TVMValue的参数数据、参数类型码以及参数个数,同时重载了[]运算符,方便对于多个参数的情况可以通过下标索引直接获取对应的入参。

class TVMArgs { public:  const TVMValue* values; // 参数数据  const int* type_codes; // 类型码  int num_args; // 参数个数  TVMArgs(const TVMValue* values, const int* type_codes, int num_args)      : values(values), type_codes(type_codes), num_args(num_args) {} // 构造函数  inline int size() const; // 获取参数个数  inline TVMArgValue operator[](int i) const; // 重载[]运算符,通过下标索引获取};
  • PackedFunc

PackedFunc就是通过std::function来实现,std::function最大的好处是可以针对不同的可调用实体形成统一的调用方式:

class PackedFunc { public:  using FType = std::function<void(TVMArgs args, TVMRetValue* rv)>; // 声明FType为函数包装器的别名  // 构造函数  PackedFunc() {}  PackedFunc(std::nullptr_t null) {}   explicit PackedFunc(FType body) : body_(body) {} // 传入FType初始化私有成员body_  template <typename... Args>  inline TVMRetValue operator()(Args&&... args) const; // 运算符()重载,直接传入没有Packed为TVMArgs的参数  inline void CallPacked(TVMArgs args, TVMRetValue* rv) const; // 调用Packed的函数,传入已经Packed的参数  inline FType body() const; // 返回私有成员body_  bool operator==(std::nullptr_t null) const { return body_ == nullptr; } // 重载==运算符,判断PackedFunc是否为空  bool operator!=(std::nullptr_t null) const { return body_ != nullptr; } // 重载!=运算符,判断PackedFunc是否非空 private:  FType body_; // 函数包装器,用于包裹需要Packed的函数};

其中的成员函数实现如下:

// 运算符()重载template <typename... Args> // C++的parameter packinline TVMRetValue PackedFunc::operator()(Args&&... args) const {  const int kNumArgs = sizeof...(Args); // 计算可变输入参数的个数  const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;  TVMValue values[kArraySize];  int type_codes[kArraySize];  // 完美转发,遍历每个入参然后通过TVMArgsSetter来赋值  detail::for_each(TVMArgsSetter(values, type_codes), std::forward<Args>(args)...);   TVMRetValue rv;  body_(TVMArgs(values, type_codes, kNumArgs), &rv); // 调用包裹函数所指向的函数  return rv;}
// 调用Packed的函数inline void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const { body_(args, rv); } // 直接调用body_包裹的函数
// 获取函数包装器inline PackedFunc::FType PackedFunc::body() const { return body_; } // 返回函数包装器
  • TypedPackedFunc

TypedPackedFunc是PackedFunc的一个封装,TVM鼓励开发者在使用C++代码开发的时候尽量使用这个类而不是直接使用PackedFunc,它增加了编译时的类型检查,可以作为参数传给PackedFunc,可以给TVMRetValue赋值,并且可以直接转换为PackedFunc。

2

关联流程

Python端导入TVM的动态链接库及调用的流程如下图所示:

图片

(1)TVM的Python代码从python/tvm/__init__.py中开始执行

from ._ffi.base import TVMError, __version__进而调用python/tvm/_ffi/__init__.py导入base及registry相关组件:from . import _pyversionfrom .base import register_errorfrom .registry import register_object, register_func, register_extensionfrom .registry import _init_api, get_global_func

在base.py执行完_load_lib函数后,全局变量_LIB和_LIB_NAME都完成了初始化,其中_LIB是一个ctypes.CDLL类型的变量,它是Python与C++部分进行交互的入口,可以理解为操作TVM动态链接库函数符号的全局句柄,而_LIB_NAME是“libtvm.so”字符串。

(2)导入runtime相关的组件

from .runtime.object import Objectfrom .runtime.ndarray import device, cpu, cuda, gpu, opencl, cl, vulkan, metal, mtlfrom .runtime.ndarray import vpi, rocm, ext_dev, hexagonfrom .runtime import ndarray as nd

这里首先会调用python/tvm/runtime/__init__.py,它会执行:

from .packed_func import PackedFunc

从而定义一个全局的PackedFuncHandle:

PackedFuncHandle = ctypes.c_void_pclass PackedFunc(PackedFuncBase): # 这是个空的子类,实现都在父类_set_class_packed_func(PackedFunc) # 设置一个全局的控制句柄

类PackedFuncBase()在整个流程中起到了最关键的作用:

class PackedFuncBase(object):    __slots__ = ["handle", "is_global"]    # pylint: disable=no-member    def __init__(self, handle, is_global):        self.handle = handle        self.is_global = is_global    def __del__(self):        if not self.is_global and _LIB is not None:            if _LIB.TVMFuncFree(self.handle) != 0:                raise get_last_ffi_error()    def __call__(self, *args):  # 重载了函数调用运算符“()”        temp_args = []        values, tcodes, num_args = _make_tvm_args(args, temp_args)        ret_val = TVMValue()        ret_tcode = ctypes.c_int()        if (            _LIB.TVMFuncCall(self.handle, values, tcodes,                ctypes.c_int(num_args),                ctypes.byref(ret_val),                ctypes.byref(ret_tcode),            )            != 0        ):            raise get_last_ffi_error()        _ = temp_args        _ = args        return RETURN_SWITCH[ret_tcode.value](ret_val)

它重载了__call__()函数,类似于C++中重载了函数调用运算符“()”,内部调用了_LIB.TVMFuncCall(handle),把保存有C++ PackedFunc对象地址的handle以及相关的参数传递进去。

TVMFuncCall的代码如下(函数实现在tvm/src/runtime/c_runtime_api.cc):

int TVMFuncCall(TVMFunctionHandle func, TVMValue* args, int* arg_type_codes, int num_args,                TVMValue* ret_val, int* ret_type_code) {  API_BEGIN();  TVMRetValue rv;  // 强转为PackedFunc *后调用CallPacked(),最终相当于直接调用PackedFunc的包裹函数body_(参看上一小节的PackedFunc类的实现分析)  (*static_cast<const PackedFunc*>(func)).CallPacked(TVMArgs(args, arg_type_codes, num_args), &rv);  // 处理返回值  ...  API_END();}

(3)初始化C++端API

TVM中Python端的组件接口都会通过以下方式进行初始化注册,如relay中的_make.py:

tvm._ffi._init_api("relay._make", __name__)

而_init_api()会通过模块名称遍历C++注册的全局函数列表,这个列表是由_LIB.TVMFuncListGlobalNames()返回的结果(python/tvm/_ffi/registry.py):

def _init_api_prefix(module_name, prefix):    // 每当导入新模块,全局字典sys.modules将记录该模块,当再次导入该模块时,会直接到字典中查找,从而加快程序运行的速度    module = sys.modules[module_name]     // list_global_func_names()通过_LIB.TVMFuncListGlobalNames()得到函数列表    for name in list_global_func_names():         if not name.startswith(prefix):            continue        fname = name[len(prefix) + 1 :]        target_module = module        if fname.find(".") != -1:            continue        f = get_global_func(name)  // 根据        ff = _get_api(f)        ff.__name__ = fname        ff.__doc__ = "TVM PackedFunc %s. " % fname        setattr(target_module, ff.__name__, ff)

(4)通过_get_global_func()获取C++创建的PackedFunc
_get_global_func中调用了TVMFuncGetGlobal() API:

def _get_global_func(name, allow_missing=False):    handle = PackedFuncHandle()    check_call(_LIB.TVMFuncGetGlobal(c_str(name), ctypes.byref(handle))) // new一个PackeFunc对象,通过handle传出来    if handle.value:        return _make_packed_func(handle, False) //    if allow_missing:        return None    raise ValueError("Cannot find global function %s" % name)

而在TVMFuncGetGlobal的实现中,handle最终保存了一个在C++端 new 出来的PackedFunc对象指针:

// src/runtime/registry.ccint TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) {  const tvm::runtime::PackedFunc* fp = tvm::runtime::Registry::Get(name);  *out = new tvm::runtime::PackedFunc(*fp); // C++端创建PackedFunc对象}

最后在Python端创建PackedFunc类,并用C++端 new 出来的PackedFunc对象指针对其进行初始化:

def _make_packed_func(handle, is_global):    """Make a packed function class"""    obj = _CLASS_PACKED_FUNC.__new__(_CLASS_PACKED_FUNC) // 创建Python端的PackedFunc对象    obj.is_global = is_global    obj.handle = handle // 将C++端的handle赋值给Python端的handle    return obj

由于Python端的PackedFuncBase重载了__call__方法,而__call__方法中调用了C++端的TVMFuncCall(handle),从而完成了从Python端PackedFunc对象的执行到C++端PackedFunc对象的执行的整个流程。

四、总结

本文介绍了PackedFunc的数据结构定义以及Python端与C++端之间通过PackedFunc调用的流程,希望能对读者有所帮助。

风语者!平时喜欢研究各种技术,目前在从事后端开发工作,热爱生活、热爱工作。