- 表达式模板是为了支持一种数值数组的类引入的技术。如希望可以像内置类型一样对数组进行下列操作
Array<double> x(1000), y(1000);
...
x = 1.2*x + x*y;
- 要获得高效率,同时支持上面这种紧凑写法,则需要通过表达式模板来完成。表达式模板和metaprogramming是互补的,metaprogramming主要用于小的、大小固定的数组,表达式模板则用于能在运行期确定大小、中等大小的数组
临时变量和分割循环
- 了解表达式模板前,先看一种简单的数值数组操作的模板实现
// exprtmpl/sarray1.hpp
#include <cstddef>
#include <cassert>
template<typename T>
class SArray { // simple array
public:
// create array with initial size
explicit SArray (size_t s)
: storage(new T[s]), storage_size(s) {
init();
}
// copy constructor
SArray (SArray<T> const& orig)
: storage(new T[orig.size()]), storage_size(orig.size()) {
copy(orig);
}
// destructor: free memory
~SArray() {
delete[] storage;
}
// assignment operator
SArray<T>& operator= (SArray<T> const& orig) {
if (&orig!=this) {
copy(orig);
}
return *this;
}
// return size
size_t size() const {
return storage_size;
}
// index operator for constants and variables
T const& operator[] (std::size_t idx) const {
return storage[idx];
}
T& operator[] (std::size_t idx) {
return storage[idx];
}
protected:
// init values with default constructor
void init() {
for (std::size_t idx = 0; idx<size(); ++idx) {
storage[idx] = T();
}
}
// copy values of another array
void copy (SArray<T> const& orig) {
assert(size()==orig.size());
for (std::size_t idx = 0; idx<size(); ++idx) {
storage[idx] = orig.storage[idx];
}
}
private:
T* storage; // storage of the elements
std::size_t storage_size; // number of elements
};
// exprtmpl/sarrayops1.hpp
// addition of two SArrays
template<typename T>
SArray<T> operator+ (SArray<T> const& a, SArray<T> const& b)
{
assert(a.size()==b.size());
SArray<T> result(a.size());
for (std::size_t k = 0; k<a.size(); ++k) {
result[k] = a[k]+b[k];
}
return result;
}
// multiplication of two SArrays
template<typename T>
SArray<T> operator* (SArray<T> const& a, SArray<T> const& b)
{
assert(a.size()==b.size());
SArray<T> result(a.size());
for (std::size_t k = 0; k<a.size(); ++k) {
result[k] = a[k]*b[k];
}
return result;
}
// multiplication of scalar and SArray
template<typename T>
SArray<T> operator* (T const& s, SArray<T> const& a)
{
SArray<T> result(a.size());
for (std::size_t k = 0; k<a.size(); ++k) {
result[k] = s*a[k];
}
return result;
}
// exprtmpl/sarray1.cpp
#include "sarray1.hpp"
#include "sarrayops1.hpp"
int main()
{
SArray<double> x(1000), y(1000);
...
x = 1.2*x + x*y;
}
- 但上面的实现非常低效,有两方面原因
- 每个运算符操作(除了赋值运算符)至少要生成一个临时数组,例子中编译器不执行任何附加的临时拷贝,也至少会生成3个大小为1000的临时数组
- 运算符程序每次使用都要求对实参和结果数组进行额外遍历,例子中只生成一个SArray对象,需要读取6000次double值,写入4000次double值
tmp1 = 1.2*x; // 循环1000次元素操作,再加上创建和删除tmp1
tmp2 = x*y // 循环1000次元素操作,再加上创建和删除tmp2
tmp3 = tmp1+tmp2; // 循环1000次读写操作,再加上创建和删除tmp3
x = tmp3; // 1000次读操作和写操作
- 对于元素很多的数组,没有足够的内存容纳这些临时对象,每个数值数组程序库的实现都会面临这个问题,因此通常使用computed assignments(如+=、*=)来代替前面的赋值运算符,这样不需要创建任何临时对象
// exprtmpl/sarrayops2.hpp
// additive assignment of SArray
template<class T>
SArray<T>& SArray<T>::operator+= (SArray<T> const& b)
{
assert(size()==orig.size());
for (std::size_t k = 0; k<size(); ++k) {
(*this)[k] += b[k];
}
return *this;
}
// multiplicative assignment of SArray
template<class T>
SArray<T>& SArray<T>::operator*= (SArray<T> const& b)
{
assert(size()==orig.size());
for (std::size_t k = 0; k<size(); ++k) {
(*this)[k] *= b[k];
}
return *this;
}
// multiplicative assignment of scalar
template<class T>
SArray<T>& SArray<T>::operator*= (T const& s)
{
for (std::size_t k = 0; k<size(); ++k) {
(*this)[k] *= s;
}
return *this;
}
// exprtmpl/sarray2.cpp
#include "sarray2.hpp"
#include "sarrayops1.hpp"
#include "sarrayops2.hpp"
int main()
{
SArray<double> x(1000), y(1000);
//...
// process x = 1.2*x + x*y
SArray<double> tmp(x);
tmp *= y;
x *= 1.2;
x += tmp;
}
- 显然使用computed assignments有下列明显缺点
- 符号变得不雅观
- 仍要创建一个非必要的局部变量tmp
- 循环被分割成多个操作,这里是3个,意味着要对double进行6000次读和4000次写操作
- 实际上所期望的操作是针对数组每个下标,只对表达式进行一次理想循环
int main()
{
SArray<double> x(1000), y(1000);
...
for (int idx = 0; idx<x.size(); ++idx) {
x[idx] = 1.2*x[idx] + x[idx]*y[idx];
}
}
- 现在就不需要局部数组了,每次迭代只需要进行两次读(x[idx] and y[idx])和一次写(x[k]),在手工循环中总共只需要2000次读和1000次写操作。最后我们希望既能得到高性能,又不需要手工循环,需要用到下面的技术,让代码更优雅且减少错误产生
在模板实参中编码表达式
- 对前面的问题,一个很好的解决方法是,直到看到整个表达式(上例中为调用赋值运算符时)才对表达式各部分求值,因此求值前必须记录每个对象和该对象上的每个操作,这些操作在编译期已经确定,因此可以用模板实参编码
1.2*x + x*y;
- 上面的表达式中,1.2*x不是一个新的数组,而是一个用于表示x的每个值都乘1.2的对象,x*y表示x的每个元素乘y相应元素,最后需要结果数组值时才进行计算,即只是先存储用于后来求值的一种表示,并没有真正进行计算。把上面表达式转为一个具有如下类型的对象
A_Add<A_Mult<A_Scalar<double>,Array<double>>,
A_Mult<Array<double>,Array<double>>>
Tree representation of expression 1.2*x+x*y
表达式模板的操作数
- 为了完整表示整个表达式,一方面在每个A_Add和A_Mult对象中,必须存储指向实参的引用,另一方面在A_Scalar对象中需要记录这个表示放大倍数的值或引用,下面是对这些操作数的可行定义
// exprtmpl/exprops1.hpp
#include <cstddef>
#include <cassert>
// include helper class traits template to select whether to refer to an
// ``expression template node'' either ``by value'' or ``by reference.''
#include "exprops1a.hpp"
// class for objects that represent the addition of two operands
template <typename T, typename OP1, typename OP2>
class A_Add {
private:
typename A_Traits<OP1>::ExprRef op1; // first operand
typename A_Traits<OP2>::ExprRef op2; // second operand
public:
// constructor initializes references to operands
A_Add (OP1 const& a, OP2 const& b)
: op1(a), op2(b) {
}
// compute sum when value requested
T operator[] (size_t idx) const {
return op1[idx] + op2[idx];
}
// size is maximum size
std::size_t size() const {
assert (op1.size()==0 || op2.size()==0
|| op1.size()==op2.size());
return op1.size()!=0 ? op1.size() : op2.size();
}
};
// class for objects that represent the multiplication of two operands
template <typename T, typename OP1, typename OP2>
class A_Mult {
private:
typename A_Traits<OP1>::ExprRef op1; // first operand
typename A_Traits<OP2>::ExprRef op2; // second operand
public:
// constructor initializes references to operands
A_Mult (OP1 const& a, OP2 const& b)
: op1(a), op2(b) {
}
// compute product when value requested
T operator[] (size_t idx) const {
return op1[idx] * op2[idx];
}
// size is maximum size
std::size_t size() const {
assert (op1.size()==0 || op2.size()==0
|| op1.size()==op2.size());
return op1.size()!=0 ? op1.size() : op2.size();
}
};
- 由代码可见,增加了下标运算符和查询容量大小的操作,从而可以根据该对象子节点(见上图)的相应操作来计算该节点的大小和每个元素的值。对于只涉及到数组的操作,结果数组的大小是其中某个操作数的大小,然而对于同时涉及到数组和scalar的操作,结果数组的大小就是操作数数组的大小。为了区分数组操作数和scalar操作数,假定scalar大小为0,模板A_Scalar定义如下
// exprtmpl/exprscalar.hpp
// class for objects that represent scalars
template <typename T>
class A_Scalar {
private:
T const& s; // value of the scalar
public:
// constructor initializes value
constexpr A_Scalar (T const& v)
: s(v) {
}
// for index operations the scalar is the value of each element
constexpr T const& operator[] (std::size_t) const {
return s;
}
// scalars have zero as size
constexpr std::size_t size() const {
return 0;
};
};
- 由上面代码可以看出,A_Scalar模板也提供了一个索引运算符。在表达式内部,A_Scalar表示的是一个每个索引都相应相同scalar值的数组。运算符类还使用了一个辅助类A_Traits来定义操作数成员
typename A_Traits<OP1>::ExprRef op1; // first operand
typename A_Traits<OP2>::ExprRef op2; // second operand
- 这种做法是必要的,通常可以把这些操作数声明为引用类型,因为大多数局部节点在顶层表达式绑定,生命期能延续到完整表达式的求值。唯一例外的是A_Scalar节点,它在运算符函数内部绑定,不能一直存在到完整表达式的求值,因此需要传值拷贝而非传引用,即需要以下性质的成员
// 通常情况下是const&
OP1 const& op1; // refer to first operand by reference
OP2 const& op2; // refer to second operand by reference
// 但对scalar值是普通值
OP1 op1; // refer to first operand by value
OP2 op2; // refer to second operand by value
- 这就要用到trait class,它定义了一个针对大多数const引用的基本模板,同时也定义了针对scalar的特化
// exprtmpl/exprops1a.hpp
/* helper traits class to select how to refer to an ''expression template node''
* - in general: by reference
* - for scalars: by value
*/
template <typename T> class A_Scalar;
// primary template
template <typename T>
class A_Traits {
public:
using ExprRef = T const&; // type to refer to is constant reference
};
// partial specialization for scalars
template <typename T>
class A_Traits<A_Scalar<T>> {
public:
using ExprRef = A_Scalar<T>; // type to refer to is ordinary value
};
- 另外,如果A_Scalar对象引用的是顶层定义的scalar,对这些scalar也可以用引用类型
Array类型
- 既然能使用轻量级的表达式模板对表达式编码,下面创建一个Array类型,它既能针对占用实际内存的数组,同时也适用于表达式模板。接口设计上应该与占用存储空间的真实数组相似,同时要与基于数组的表达式具有相同表示,Array模板声明如下
template<typename T, typename Rep = SArray<T>>
class Array;
- 上面代码中,Rep类型要么是SArray(前提是Array必须是一个占用实际存储空间的数组),要么是一个用于编码表达式的嵌套template-id,如A_Add和A_Mult。我们将用同一种方式处理这两种途径产生的Array实例化体,以简化后期编码,如果用A_Mult等类型替换Rep,某些成员不能被实例化,但实际中Array模板的定义不需要声明用于区分上面两种情况(即SArray和template-id)的特化,下面是一个定义,下面用到了decltype(auto),处理数组下标很方便
// exprtmpl/exprarray.hpp
#include <cstddef>
#include <cassert>
#include "sarray1.hpp"
template <typename T, typename Rep = SArray<T>>
class Array {
private:
Rep expr_rep; // (access to) the data of the array
public:
// create array with initial size
explicit Array (std::size_t s)
: expr_rep(s) {
}
// create array from possible representation
Array (Rep const& rb)
: expr_rep(rb) {
}
// assignment operator for same type
Array& operator= (Array const& b) {
assert(size()==b.size());
for (std::size_t idx = 0; idx<b.size(); ++idx) {
expr_rep[idx] = b[idx];
}
return *this;
}
// assignment operator for arrays of different type
template<typename T2, typename Rep2>
Array& operator= (Array<T2, Rep2> const& b) {
assert(size()==b.size());
for (std::size_t idx = 0; idx<b.size(); ++idx) {
expr_rep[idx] = b[idx];
}
return *this;
}
// size is size of represented data
std::size_t size() const {
return expr_rep.size();
}
// index operator for constants and variables
decltype(auto) operator[] (std::size_t idx) const {
assert(idx<size());
return expr_rep[idx];
}
T& operator[] (std::size_t idx) {
assert(idx<size());
return expr_rep[idx];
}
// return what the array currently represents
Rep const& rep() const {
return expr_rep;
}
Rep& rep() {
return expr_rep;
}
};
- 这里的许多操作都只是简单委托给所含的Rep对象,但拷贝另一个数组时必须考虑,另一个数组是否基于表达式模板,因此需要根据Rep的表示对拷贝运算符进行参数化,即声明针对两种不同情况的赋值运算符
运算符
- 目前只是实现了用于代表运算符的、针对数值Array模板的运算符操作(如A_Add),但没实现运算符本身(如+)。正如前面所说,这些运算符只是用于代表表达式模板对象,实际上并不对结果数组求值。显然对每个普通的二元运算符,必须实现三个版本,即array-array,array-scalar, scalar-array,如为了计算前面的表达式初始值需要用到下面的运算符
// exprtmpl/exprops2.hpp
// addition of two Arrays
template <typename T, typename R1, typename R2>
Array<T,A_Add<T,R1,R2>>
operator+ (Array<T,R1> const& a, Array<T,R2> const& b) {
return Array<T,A_Add<T,R1,R2>>
(A_Add<T,R1,R2>(a.rep(),b.rep()));
}
// multiplication of two Arrays
template <typename T, typename R1, typename R2>
Array<T, A_Mult<T,R1,R2>>
operator* (Array<T,R1> const& a, Array<T,R2> const& b) {
return Array<T,A_Mult<T,R1,R2>>
(A_Mult<T,R1,R2>(a.rep(), b.rep()));
}
// multiplication of scalar and Array
template <typename T, typename R2>
Array<T, A_Mult<T,A_Scalar<T>,R2>>
operator* (T const& s, Array<T,R2> const& b) {
return Array<T,A_Mult<T,A_Scalar<T>,R2>>
(A_Mult<T,A_Scalar<T>,R2>(A_Scalar<T>(s), b.rep()));
}
- 这些运算符声明看起来复杂,实际上函数做的工作不多,如对两个数组的加法运算符,首先生成一个用于A_Add<>对象用于表示运算符和操作数
A_Add<T,R1,R2>(a.rep(),b.rep())
- 并把这个对象封装到一个数组中,从而可以借助数组来操作这个运算结果,实际上其他对象也是这样处理的
return Array<T,A_Add<T,R1,R2>> (... );
- 对scalar乘法使用了A_Scalar模板创建A_Mult对象
A_Mult<T,A_Scalar<T>,R2>(A_Scalar<T>(s), b.rep())
return Array<T,A_Mult<T,A_Scalar<T>,R2>> (... );
- 其他二元运算符实现类似,也可以用宏来声明这些运算符,从而只需要使用较少的代码
回顾
- 对前面的例子,现在进行一个自顶向下的回顾。下面是要分析的代码
int main()
{
Array<double> x(1000), y(1000);
...
x = 1.2*x + x*y;
}
- 由于x和y的定义中省略了Rep实参,所以该参数使用默认值SArray<double>,因此x和y是占用真实内存的数组,也就是说它们不只是用于记录操作。当解析表达式1.2*x + x*y时,编译器首先应用最左边的*,它是一个scalar-array运算符,于是重载解析规则选择operator*的scalar-array形式
template<typename T, typename R2>
Array<T, A_Mult<T,A_Scalar<T>,R2>>
operator* (T const& s, Array<T,R2> const& b) {
return Array<T,A_Mult<T,A_Scalar<T>,R2>>
(A_Mult<T,A_Scalar<T>,R2>(A_Scalar<T>(s), b.rep()));
}
- 其中操作数类型是double和Array<double, SArray<double>>,因此实际的结果类型是Array<double, A_Mult<double, A_Scalar<double>, SArray<double>>>,而结果值是一个构造自double值1.2的A_Scalar<double>对象和一个表示对象x的SArrayr<double>对象
- 接着对第二个乘法求值,x*y是一个array-array操作,使用相应的operator*
template <typename T, typename R1, typename R2>
Array<T, A_Mult<T,R1,R2>>
operator* (Array<T,R1> const& a, Array<T,R2> const& b) {
return Array<T,A_Mult<T,R1,R2>>
(A_Mult<T,R1,R2>(a.rep(), b.rep()));
}
- 而两个操作数的类型都是Array<double, SArray<double>>,因此结果类型为Array<double, A_Mult<double, SArray<double>, SArray<double>>>,这次A_Mult所封装的两个参数对象都引用了一个SArray<double>,代表一个用于表示x对象,另一个用于表示y对象
- 最后对+运算符求值,依然是array-array操作 ,操作数类型是上面推断的类型,因此调用针对array-array的oprator+
template <typename T, typename R1, typename R2>
Array<T,A_Add<T,R1,R2>>
operator+ (Array<T,R1> const& a, Array<T,R2> const& b) {
return Array<T,A_Add<T,R1,R2> >
(A_Add<T,R1,R2>(a.rep(),b.rep()));
}
- 其中用double替换T,R1替换为A_Mult<double, A_Scalar<double>, SArray<double> >,R2替换为A_Mult<double, SArray<double>, SArray<double> >,最终赋值运算符右边的表达式类型为
Array<double,
A_Add<double,
A_Mult<double, A_Scalar<double>, SArray<double>>,
A_Mult<double, SArray<double>, SArray<double>>>>
template <typename T, typename Rep = SArray<T> >
class Array {
public:
...
// assignment operator for arrays of different type
template<typename T2, typename Rep2>
Array& operator= (Array<T2, Rep2> const& b) {
assert(size()==b.size());
for (std::size_t idx = 0; idx<b.size(); ++idx) {
expr_rep[idx] = b[idx];
}
return *this;
}
...
};
- 其中赋值运算符将会使用右边Array(即b)的下标运算符来计算目标数组x的每个元素,右边Array的实际类型为
A_Add<double,
A_Mult<double, A_Scalar<double>, SArray<double>>,
A_Mult<double, SArray<double>, SArray<double>>>>
- 如果仔细跟踪这个下标操作,对一个给定的下标x将得到
(1.2*x[idx]) + (x[idx]*y[idx])
表达式模板赋值
- 对于一个Rep实参基于A_Mult或A_Add表达式模板的数组,是不能为该数组实例化写操作的(即编写a+b=c的式子毫无意义),但可以编写其他的表达式模板从而能对这些表达式模板的结果赋值,如以具有整数值数组为下标的索引操作通常会涉及到子集的选择,即x[y] = 2*x[y]等价于
for (std::size_t idx = 0; idx<y.size(); ++idx) {
x[y[idx]] = 2*x[y[idx]];
}
- 为了使上面这种写法可行,必须令这种基于表达式模板的数组的行为能像一个左值,即可写。而且类似于这样的表达式模板的组件和A_Mult等类似,唯一区别在于它提供了下标运算符的const版本和non-const版本,并返回一个左值引用,用decltype(auto)处理数组下标很方便
// exprtmpl/exprops3.hpp
template<typename T, typename A1, typename A2>
class A_Subscript {
public:
// constructor initializes references to operands
A_Subscript (A1 const& a, A2 const& b)
: a1(a), a2(b) {
}
// process subscription when value requested
decltype(auto) operator[] (size_t idx) const {
return a1[a2[idx]];
}
T& operator[] (size_t idx) {
return a1[a2[idx]];
}
// size is size of inner array
std::size_t size() const {
return a2.size();
}
private:
A1 const& a1; // reference to first operand
A2 const& a2; // reference to second operand
};
- 针对这种运用子集语义的、扩展的下标运算符,要为Array模板定义额外的下标运算符,其中一个定义如下(还需要一个针对const的对应版本)
// exprtmpl/exprops4.hpp
template<typename T, typename R>
template<typename T2, typename R2> inline
Array<T, A_Subscript<T, R, R2>>
Array<T, R>::operator[](Array<T2, R2> const& b) {
return Array<T, A_Subscript<T, R, R2>>
(A_Subscript<T, R, R2>(*this, b));
}
表达式模板的性能与约束
- 表达式模板可以提高数组操作性能,跟踪其行为可以发现许多很小的内联函数互相调用,在调用堆栈还分配了许多小的表达式模板对象,因此编译器必须执行完整的内联小对象和去除小对象操作,来产生性能上和手写循环媲美的代码
- 表达式模板并没有解决所有涉及数组数值操作的问题,如对x = A*x这种矩阵-vector乘法,x是一个大小为n的vector,A是一个n*n矩阵。问题在于临时变量的使用不可避免,因为最终结果的每个元素都依赖于最初x的每个元素,而表达式模板会在一次计算上马上更新x的首个元素,计算下一个元素时用到这个已更新的元素就改变了原来的数组
- 但针对x = A*y,如果x和y不互为别名,就不需要一个临时对象,这表示必须在运行期知道操作数是否为别名关系,反过来又说明必须生成一个用于表示表达式树的运行期结构,而不是在表达式模板的类型中编码这棵树
网友评论