理解 Relax 抽象
7831
数据流程块#
在 relax 函数中另一个重要的元素是 R.dataflow() 范围注释。
with R.dataflow():
lv = R.call_tir(cls.linear, (x, w0, b0), out_sinfo=R.Tensor((n, 256), dtype="float32"))
lv1 = R.call_tir(cls.relu, (lv0,), out_sinfo=R.Tensor((n, 256), dtype="float32"))
lv2 = R.call_tir(cls.linear, (lv1, w1, b1), out_sinfo=R.Tensor((b, 10), dtype="float32"))
R.output(lv2)
在讨论数据流块之前,首先介绍“纯函数”和“副作用”的概念。函数是“纯的”或者说是“无副作用的”,如果它满足以下条件:
它仅从输入中读取数据,并通过输出返回结果。
它不会改变程序的其他部分(比如增加 “全局计数器”)。
例如,所有的 R.call_tir 函数都是纯函数,因为它们仅从输入读取数据并将输出写入另一个新分配的张量。然而,原地操作不是纯函数,换句话说,它们是有副作用的函数,因为它们会改变现有的中间或输入张量。
数据流块是一种方法,用于标记程序的计算图区域。具体来说,在数据流块内部,所有运算都需要是无副作用的。在数据流块外部,运算可以包含副作用。
备注
常见的问题是,为什么需要手动标记数据流块而不是自动推断它们。采取这种方法有两个主要理由:
数据流块的自动推断可能会面临挑战且不够精确,尤其是在处理对打包函数(如 cuBLAS 集成)的调用时。通过手动标记数据流块,可以使得编译器能够准确理解并优化程序的数据流。
许多优化只能在数据流块内应用。例如,融合优化仅限于单个数据流块内的运算。如果编译器错误地推断数据流边界,可能会错过关键的优化机会,从而可能影响程序的性能。
通过允许手动标记数据流块,确保编译器拥有最准确的信息进行处理,从而带来更有效的优化。