美文网首页C++ Templates
【C++ Templates(25)】表达式模板

【C++ Templates(25)】表达式模板

作者: downdemo | 来源:发表于2018-07-09 09:09 被阅读60次
  • 表达式模板是为了支持一种数值数组的类引入的技术。如希望可以像内置类型一样对数组进行下列操作
Array<double> x(1000), y(1000);
...
x = 1.2*x + x*y;
  • 要获得高效率,同时支持上面这种紧凑写法,则需要通过表达式模板来完成。表达式模板和metaprogramming是互补的,metaprogramming主要用于小的、大小固定的数组,表达式模板则用于能在运行期确定大小、中等大小的数组

临时变量和分割循环

  • 了解表达式模板前,先看一种简单的数值数组操作的模板实现
// exprtmpl/sarray1.hpp

#include <cstddef>
#include <cassert>

template<typename T>
class SArray { // simple array
public:
    // create array with initial size
    explicit SArray (size_t s)
     : storage(new T[s]), storage_size(s) {
        init();
    }

    // copy constructor
    SArray (SArray<T> const& orig)
     : storage(new T[orig.size()]), storage_size(orig.size()) {
        copy(orig);
    }

    // destructor: free memory
    ~SArray() {
        delete[] storage;
    }

    // assignment operator
    SArray<T>& operator= (SArray<T> const& orig) {
        if (&orig!=this) {
            copy(orig);
        }
        return *this;
    }

    // return size
    size_t size() const {
        return storage_size;
    }

    // index operator for constants and variables
    T const& operator[] (std::size_t idx) const {
        return storage[idx];
    }
    T& operator[] (std::size_t idx) {
        return storage[idx];
    }

protected:
    // init values with default constructor
    void init() {
        for (std::size_t idx = 0; idx<size(); ++idx) {
            storage[idx] = T();
        }
    }
    // copy values of another array
    void copy (SArray<T> const& orig) {
        assert(size()==orig.size());
        for (std::size_t idx = 0; idx<size(); ++idx) {
            storage[idx] = orig.storage[idx];
        }
    }

private:
    T*     storage;       // storage of the elements
    std::size_t storage_size;  // number of elements
};
  • 数值运算符实现如下
// exprtmpl/sarrayops1.hpp

// addition of two SArrays
template<typename T>
SArray<T> operator+ (SArray<T> const& a, SArray<T> const& b)
{
    assert(a.size()==b.size());
    SArray<T> result(a.size());
    for (std::size_t k = 0; k<a.size(); ++k) {
        result[k] = a[k]+b[k];
    }
    return result;
}

// multiplication of two SArrays
template<typename T>
SArray<T> operator* (SArray<T> const& a, SArray<T> const& b)
{
    assert(a.size()==b.size());
    SArray<T> result(a.size());
    for (std::size_t k = 0; k<a.size(); ++k) {
        result[k] = a[k]*b[k];
    }
    return result;
}

// multiplication of scalar and SArray
template<typename T>
SArray<T> operator* (T const& s, SArray<T> const& a)
{
    SArray<T> result(a.size());
    for (std::size_t k = 0; k<a.size(); ++k) {
        result[k] = s*a[k];
    }
    return result;
}
  • 上面这些运算符足够完成例子中的表达式计算了
// exprtmpl/sarray1.cpp

#include "sarray1.hpp"
#include "sarrayops1.hpp"

int main()
{
    SArray<double> x(1000), y(1000);
    ...
    x = 1.2*x + x*y;
}
  • 但上面的实现非常低效,有两方面原因
    • 每个运算符操作(除了赋值运算符)至少要生成一个临时数组,例子中编译器不执行任何附加的临时拷贝,也至少会生成3个大小为1000的临时数组
    • 运算符程序每次使用都要求对实参和结果数组进行额外遍历,例子中只生成一个SArray对象,需要读取6000次double值,写入4000次double值
tmp1 = 1.2*x; // 循环1000次元素操作,再加上创建和删除tmp1
tmp2 = x*y // 循环1000次元素操作,再加上创建和删除tmp2
tmp3 = tmp1+tmp2; // 循环1000次读写操作,再加上创建和删除tmp3
x = tmp3; // 1000次读操作和写操作
  • 对于元素很多的数组,没有足够的内存容纳这些临时对象,每个数值数组程序库的实现都会面临这个问题,因此通常使用computed assignments(如+=、*=)来代替前面的赋值运算符,这样不需要创建任何临时对象
// exprtmpl/sarrayops2.hpp

// additive assignment of SArray
template<class T>
SArray<T>& SArray<T>::operator+= (SArray<T> const& b)
{
    assert(size()==orig.size());
    for (std::size_t k = 0; k<size(); ++k) {
        (*this)[k] += b[k];
    }
    return *this;
}

// multiplicative assignment of SArray
template<class T>
SArray<T>& SArray<T>::operator*= (SArray<T> const& b)
{
    assert(size()==orig.size());
    for (std::size_t k = 0; k<size(); ++k) {
        (*this)[k] *= b[k];
    }
    return *this;
}

// multiplicative assignment of scalar
template<class T>
SArray<T>& SArray<T>::operator*= (T const& s)
{
    for (std::size_t k = 0; k<size(); ++k) {
        (*this)[k] *= s;
    }
    return *this;
}
  • 用这些运算符改写前面的例子
// exprtmpl/sarray2.cpp

#include "sarray2.hpp"
#include "sarrayops1.hpp"
#include "sarrayops2.hpp"

int main()
{
    SArray<double> x(1000), y(1000);
    //...
    // process x = 1.2*x + x*y
    SArray<double> tmp(x);
    tmp *= y;
    x *= 1.2;
    x += tmp;
}
  • 显然使用computed assignments有下列明显缺点
    • 符号变得不雅观
    • 仍要创建一个非必要的局部变量tmp
    • 循环被分割成多个操作,这里是3个,意味着要对double进行6000次读和4000次写操作
  • 实际上所期望的操作是针对数组每个下标,只对表达式进行一次理想循环
int main()
{
    SArray<double> x(1000), y(1000);
    ...
    for (int idx = 0; idx<x.size(); ++idx) {
        x[idx] = 1.2*x[idx] + x[idx]*y[idx];
    }
}
  • 现在就不需要局部数组了,每次迭代只需要进行两次读(x[idx] and y[idx])和一次写(x[k]),在手工循环中总共只需要2000次读和1000次写操作。最后我们希望既能得到高性能,又不需要手工循环,需要用到下面的技术,让代码更优雅且减少错误产生

在模板实参中编码表达式

  • 对前面的问题,一个很好的解决方法是,直到看到整个表达式(上例中为调用赋值运算符时)才对表达式各部分求值,因此求值前必须记录每个对象和该对象上的每个操作,这些操作在编译期已经确定,因此可以用模板实参编码
1.2*x + x*y;
  • 上面的表达式中,1.2*x不是一个新的数组,而是一个用于表示x的每个值都乘1.2的对象,x*y表示x的每个元素乘y相应元素,最后需要结果数组值时才进行计算,即只是先存储用于后来求值的一种表示,并没有真正进行计算。把上面表达式转为一个具有如下类型的对象
A_Add<A_Mult<A_Scalar<double>,Array<double>>,
      A_Mult<Array<double>,Array<double>>>
  • 对这个表达式,存在一个前序语法树的表示方法
Tree representation of expression 1.2*x+x*y

表达式模板的操作数

  • 为了完整表示整个表达式,一方面在每个A_Add和A_Mult对象中,必须存储指向实参的引用,另一方面在A_Scalar对象中需要记录这个表示放大倍数的值或引用,下面是对这些操作数的可行定义
// exprtmpl/exprops1.hpp

#include <cstddef>
#include <cassert>

// include helper class traits template to select whether to refer to an
// ``expression template node'' either ``by value'' or ``by reference.''
#include "exprops1a.hpp"

// class for objects that represent the addition of two operands
template <typename T, typename OP1, typename OP2>
class A_Add {
private:
    typename A_Traits<OP1>::ExprRef op1;    // first operand
    typename A_Traits<OP2>::ExprRef op2;    // second operand

public: 
    // constructor initializes references to operands
    A_Add (OP1 const& a, OP2 const& b)
     : op1(a), op2(b) {
    }

    // compute sum when value requested
    T operator[] (size_t idx) const {
        return op1[idx] + op2[idx];
    }

    // size is maximum size
    std::size_t size() const {
        assert (op1.size()==0 || op2.size()==0
                || op1.size()==op2.size());
        return op1.size()!=0 ? op1.size() : op2.size();
    }
};

// class for objects that represent the multiplication of two operands
template <typename T, typename OP1, typename OP2>
class A_Mult {
private:
    typename A_Traits<OP1>::ExprRef op1;    // first operand
    typename A_Traits<OP2>::ExprRef op2;    // second operand

public:
    // constructor initializes references to operands
    A_Mult (OP1 const& a, OP2 const& b)
     : op1(a), op2(b) {
    }

    // compute product when value requested
    T operator[] (size_t idx) const {
        return op1[idx] * op2[idx];
    }

    // size is maximum size
    std::size_t size() const {
        assert (op1.size()==0 || op2.size()==0
                || op1.size()==op2.size());
        return op1.size()!=0 ? op1.size() : op2.size();
    }
};
  • 由代码可见,增加了下标运算符和查询容量大小的操作,从而可以根据该对象子节点(见上图)的相应操作来计算该节点的大小和每个元素的值。对于只涉及到数组的操作,结果数组的大小是其中某个操作数的大小,然而对于同时涉及到数组和scalar的操作,结果数组的大小就是操作数数组的大小。为了区分数组操作数和scalar操作数,假定scalar大小为0,模板A_Scalar定义如下
// exprtmpl/exprscalar.hpp

// class for objects that represent scalars
template <typename T>
class A_Scalar {
private:
    T const& s;  // value of the scalar

public:
    // constructor initializes value
    constexpr A_Scalar (T const& v)
     : s(v) {
    }

    // for index operations the scalar is the value of each element
    constexpr T const& operator[] (std::size_t) const {
        return s;
    }

    // scalars have zero as size
    constexpr std::size_t size() const {
        return 0;
    };
};
  • 由上面代码可以看出,A_Scalar模板也提供了一个索引运算符。在表达式内部,A_Scalar表示的是一个每个索引都相应相同scalar值的数组。运算符类还使用了一个辅助类A_Traits来定义操作数成员
typename A_Traits<OP1>::ExprRef op1; // first operand
typename A_Traits<OP2>::ExprRef op2; // second operand
  • 这种做法是必要的,通常可以把这些操作数声明为引用类型,因为大多数局部节点在顶层表达式绑定,生命期能延续到完整表达式的求值。唯一例外的是A_Scalar节点,它在运算符函数内部绑定,不能一直存在到完整表达式的求值,因此需要传值拷贝而非传引用,即需要以下性质的成员
// 通常情况下是const&
OP1 const& op1; // refer to first operand by reference
OP2 const& op2; // refer to second operand by reference
// 但对scalar值是普通值
OP1 op1; // refer to first operand by value
OP2 op2; // refer to second operand by value
  • 这就要用到trait class,它定义了一个针对大多数const引用的基本模板,同时也定义了针对scalar的特化
// exprtmpl/exprops1a.hpp

/* helper traits class to select how to refer to an ''expression template node''
 * - in general: by reference
 * - for scalars: by value
 */

template <typename T> class A_Scalar;

// primary template
template <typename T>
class A_Traits {
public:
    using ExprRef = T const&;     // type to refer to is constant reference
};

// partial specialization for scalars
template <typename T>
class A_Traits<A_Scalar<T>> {
public:
    using ExprRef = A_Scalar<T>;  // type to refer to is ordinary value
};
  • 另外,如果A_Scalar对象引用的是顶层定义的scalar,对这些scalar也可以用引用类型

Array类型

  • 既然能使用轻量级的表达式模板对表达式编码,下面创建一个Array类型,它既能针对占用实际内存的数组,同时也适用于表达式模板。接口设计上应该与占用存储空间的真实数组相似,同时要与基于数组的表达式具有相同表示,Array模板声明如下
template<typename T, typename Rep = SArray<T>>
class Array;
  • 上面代码中,Rep类型要么是SArray(前提是Array必须是一个占用实际存储空间的数组),要么是一个用于编码表达式的嵌套template-id,如A_Add和A_Mult。我们将用同一种方式处理这两种途径产生的Array实例化体,以简化后期编码,如果用A_Mult等类型替换Rep,某些成员不能被实例化,但实际中Array模板的定义不需要声明用于区分上面两种情况(即SArray和template-id)的特化,下面是一个定义,下面用到了decltype(auto),处理数组下标很方便
// exprtmpl/exprarray.hpp

#include <cstddef>
#include <cassert>
#include "sarray1.hpp"

template <typename T, typename Rep = SArray<T>>
class Array {
private:
    Rep expr_rep;   // (access to) the data of the array

public:
    // create array with initial size
    explicit Array (std::size_t s)
     : expr_rep(s) {
    }

    // create array from possible representation
    Array (Rep const& rb)
     : expr_rep(rb) {
    }

    // assignment operator for same type
    Array& operator= (Array const& b) { 
        assert(size()==b.size());
        for (std::size_t idx = 0; idx<b.size(); ++idx) {
            expr_rep[idx] = b[idx];
        }
        return *this;
    }

    // assignment operator for arrays of different type
    template<typename T2, typename Rep2>
    Array& operator= (Array<T2, Rep2> const& b) { 
        assert(size()==b.size());
        for (std::size_t idx = 0; idx<b.size(); ++idx) {
            expr_rep[idx] = b[idx];
        }
        return *this;
    }

    // size is size of represented data
    std::size_t size() const {
        return expr_rep.size();
    }

    // index operator for constants and variables
    decltype(auto) operator[] (std::size_t idx) const {
        assert(idx<size());
        return expr_rep[idx];
    }
    T& operator[] (std::size_t idx) {
        assert(idx<size());
        return expr_rep[idx];
    }

    // return what the array currently represents
    Rep const& rep() const {
        return expr_rep;
    }

    Rep& rep() {
        return expr_rep;
    }
};
  • 这里的许多操作都只是简单委托给所含的Rep对象,但拷贝另一个数组时必须考虑,另一个数组是否基于表达式模板,因此需要根据Rep的表示对拷贝运算符进行参数化,即声明针对两种不同情况的赋值运算符

运算符

  • 目前只是实现了用于代表运算符的、针对数值Array模板的运算符操作(如A_Add),但没实现运算符本身(如+)。正如前面所说,这些运算符只是用于代表表达式模板对象,实际上并不对结果数组求值。显然对每个普通的二元运算符,必须实现三个版本,即array-array,array-scalar, scalar-array,如为了计算前面的表达式初始值需要用到下面的运算符
// exprtmpl/exprops2.hpp

// addition of two Arrays
template <typename T, typename R1, typename R2>
Array<T,A_Add<T,R1,R2>>
operator+ (Array<T,R1> const& a, Array<T,R2> const& b) {
    return Array<T,A_Add<T,R1,R2>>
           (A_Add<T,R1,R2>(a.rep(),b.rep()));
}

// multiplication of two Arrays
template <typename T, typename R1, typename R2>
Array<T, A_Mult<T,R1,R2>>
operator* (Array<T,R1> const& a, Array<T,R2> const& b) {
    return Array<T,A_Mult<T,R1,R2>>
           (A_Mult<T,R1,R2>(a.rep(), b.rep()));
}

// multiplication of scalar and Array
template <typename T, typename R2>
Array<T, A_Mult<T,A_Scalar<T>,R2>>
operator* (T const& s, Array<T,R2> const& b) {
    return Array<T,A_Mult<T,A_Scalar<T>,R2>>
           (A_Mult<T,A_Scalar<T>,R2>(A_Scalar<T>(s), b.rep()));
}
  • 这些运算符声明看起来复杂,实际上函数做的工作不多,如对两个数组的加法运算符,首先生成一个用于A_Add<>对象用于表示运算符和操作数
A_Add<T,R1,R2>(a.rep(),b.rep())
  • 并把这个对象封装到一个数组中,从而可以借助数组来操作这个运算结果,实际上其他对象也是这样处理的
return Array<T,A_Add<T,R1,R2>> (... );
  • 对scalar乘法使用了A_Scalar模板创建A_Mult对象
A_Mult<T,A_Scalar<T>,R2>(A_Scalar<T>(s), b.rep())
  • 并对它进行了封装
return Array<T,A_Mult<T,A_Scalar<T>,R2>> (... );
  • 其他二元运算符实现类似,也可以用宏来声明这些运算符,从而只需要使用较少的代码

回顾

  • 对前面的例子,现在进行一个自顶向下的回顾。下面是要分析的代码
int main()
{
    Array<double> x(1000), y(1000);
    ...
    x = 1.2*x + x*y;
}
  • 由于x和y的定义中省略了Rep实参,所以该参数使用默认值SArray<double>,因此x和y是占用真实内存的数组,也就是说它们不只是用于记录操作。当解析表达式1.2*x + x*y时,编译器首先应用最左边的*,它是一个scalar-array运算符,于是重载解析规则选择operator*的scalar-array形式
template<typename T, typename R2>
Array<T, A_Mult<T,A_Scalar<T>,R2>>
operator* (T const& s, Array<T,R2> const& b) {
    return Array<T,A_Mult<T,A_Scalar<T>,R2>>
           (A_Mult<T,A_Scalar<T>,R2>(A_Scalar<T>(s), b.rep()));
}
  • 其中操作数类型是double和Array<double, SArray<double>>,因此实际的结果类型是Array<double, A_Mult<double, A_Scalar<double>, SArray<double>>>,而结果值是一个构造自double值1.2的A_Scalar<double>对象和一个表示对象x的SArrayr<double>对象
  • 接着对第二个乘法求值,x*y是一个array-array操作,使用相应的operator*
template <typename T, typename R1, typename R2>
Array<T, A_Mult<T,R1,R2>>
operator* (Array<T,R1> const& a, Array<T,R2> const& b) {
    return Array<T,A_Mult<T,R1,R2>>
        (A_Mult<T,R1,R2>(a.rep(), b.rep()));
}
  • 而两个操作数的类型都是Array<double, SArray<double>>,因此结果类型为Array<double, A_Mult<double, SArray<double>, SArray<double>>>,这次A_Mult所封装的两个参数对象都引用了一个SArray<double>,代表一个用于表示x对象,另一个用于表示y对象
  • 最后对+运算符求值,依然是array-array操作 ,操作数类型是上面推断的类型,因此调用针对array-array的oprator+
template <typename T, typename R1, typename R2>
Array<T,A_Add<T,R1,R2>>
operator+ (Array<T,R1> const& a, Array<T,R2> const& b) {
    return Array<T,A_Add<T,R1,R2> >
        (A_Add<T,R1,R2>(a.rep(),b.rep()));
}
  • 其中用double替换T,R1替换为A_Mult<double, A_Scalar<double>, SArray<double> >,R2替换为A_Mult<double, SArray<double>, SArray<double> >,最终赋值运算符右边的表达式类型为
Array<double,
    A_Add<double,
          A_Mult<double, A_Scalar<double>, SArray<double>>,
          A_Mult<double, SArray<double>, SArray<double>>>>
  • 这个类型与Array模板的赋值运算符模板进行匹配
template <typename T, typename Rep = SArray<T> >
class Array {
public:
...
    // assignment operator for arrays of different type
    template<typename T2, typename Rep2>
    Array& operator= (Array<T2, Rep2> const& b) {
        assert(size()==b.size());
        for (std::size_t idx = 0; idx<b.size(); ++idx) {
            expr_rep[idx] = b[idx];
        }
        return *this;
    }
    ...
};
  • 其中赋值运算符将会使用右边Array(即b)的下标运算符来计算目标数组x的每个元素,右边Array的实际类型为
A_Add<double,
      A_Mult<double, A_Scalar<double>, SArray<double>>,
      A_Mult<double, SArray<double>, SArray<double>>>>
  • 如果仔细跟踪这个下标操作,对一个给定的下标x将得到
(1.2*x[idx]) + (x[idx]*y[idx])
  • 这正是所期望计算的表达式

表达式模板赋值

  • 对于一个Rep实参基于A_Mult或A_Add表达式模板的数组,是不能为该数组实例化写操作的(即编写a+b=c的式子毫无意义),但可以编写其他的表达式模板从而能对这些表达式模板的结果赋值,如以具有整数值数组为下标的索引操作通常会涉及到子集的选择,即x[y] = 2*x[y]等价于
for (std::size_t idx = 0; idx<y.size(); ++idx) {
    x[y[idx]] = 2*x[y[idx]];
}
  • 为了使上面这种写法可行,必须令这种基于表达式模板的数组的行为能像一个左值,即可写。而且类似于这样的表达式模板的组件和A_Mult等类似,唯一区别在于它提供了下标运算符的const版本和non-const版本,并返回一个左值引用,用decltype(auto)处理数组下标很方便
// exprtmpl/exprops3.hpp

template<typename T, typename A1, typename A2>
class A_Subscript {
public:
    // constructor initializes references to operands
    A_Subscript (A1 const& a, A2 const& b)
     : a1(a), a2(b) {
    }

    // process subscription when value requested
    decltype(auto) operator[] (size_t idx) const {
        return a1[a2[idx]];
    }
    T& operator[] (size_t idx) {
        return a1[a2[idx]];
    }

    // size is size of inner array
    std::size_t size() const {
        return a2.size();
    }
private:
    A1 const& a1;    // reference to first operand
    A2 const& a2;    // reference to second operand
};
  • 针对这种运用子集语义的、扩展的下标运算符,要为Array模板定义额外的下标运算符,其中一个定义如下(还需要一个针对const的对应版本)
// exprtmpl/exprops4.hpp

template<typename T, typename R>
  template<typename T2, typename R2> inline
Array<T, A_Subscript<T, R, R2>>
Array<T, R>::operator[](Array<T2, R2> const& b) {
    return Array<T, A_Subscript<T, R, R2>>
           (A_Subscript<T, R, R2>(*this, b));
} 

表达式模板的性能与约束

  • 表达式模板可以提高数组操作性能,跟踪其行为可以发现许多很小的内联函数互相调用,在调用堆栈还分配了许多小的表达式模板对象,因此编译器必须执行完整的内联小对象和去除小对象操作,来产生性能上和手写循环媲美的代码
  • 表达式模板并没有解决所有涉及数组数值操作的问题,如对x = A*x这种矩阵-vector乘法,x是一个大小为n的vector,A是一个n*n矩阵。问题在于临时变量的使用不可避免,因为最终结果的每个元素都依赖于最初x的每个元素,而表达式模板会在一次计算上马上更新x的首个元素,计算下一个元素时用到这个已更新的元素就改变了原来的数组
  • 但针对x = A*y,如果x和y不互为别名,就不需要一个临时对象,这表示必须在运行期知道操作数是否为别名关系,反过来又说明必须生成一个用于表示表达式树的运行期结构,而不是在表达式模板的类型中编码这棵树

相关文章

网友评论

    本文标题:【C++ Templates(25)】表达式模板

    本文链接:https://www.haomeiwen.com/subject/xvhzuftx.html