美文网首页C++ Templates
【C++ Templates(25)】表达式模板

【C++ Templates(25)】表达式模板

作者: downdemo | 来源:发表于2018-07-09 09:09 被阅读60次
    • 表达式模板是为了支持一种数值数组的类引入的技术。如希望可以像内置类型一样对数组进行下列操作
    Array<double> x(1000), y(1000);
    ...
    x = 1.2*x + x*y;
    
    • 要获得高效率,同时支持上面这种紧凑写法,则需要通过表达式模板来完成。表达式模板和metaprogramming是互补的,metaprogramming主要用于小的、大小固定的数组,表达式模板则用于能在运行期确定大小、中等大小的数组

    临时变量和分割循环

    • 了解表达式模板前,先看一种简单的数值数组操作的模板实现
    // exprtmpl/sarray1.hpp
    
    #include <cstddef>
    #include <cassert>
    
    template<typename T>
    class SArray { // simple array
    public:
        // create array with initial size
        explicit SArray (size_t s)
         : storage(new T[s]), storage_size(s) {
            init();
        }
    
        // copy constructor
        SArray (SArray<T> const& orig)
         : storage(new T[orig.size()]), storage_size(orig.size()) {
            copy(orig);
        }
    
        // destructor: free memory
        ~SArray() {
            delete[] storage;
        }
    
        // assignment operator
        SArray<T>& operator= (SArray<T> const& orig) {
            if (&orig!=this) {
                copy(orig);
            }
            return *this;
        }
    
        // return size
        size_t size() const {
            return storage_size;
        }
    
        // index operator for constants and variables
        T const& operator[] (std::size_t idx) const {
            return storage[idx];
        }
        T& operator[] (std::size_t idx) {
            return storage[idx];
        }
    
    protected:
        // init values with default constructor
        void init() {
            for (std::size_t idx = 0; idx<size(); ++idx) {
                storage[idx] = T();
            }
        }
        // copy values of another array
        void copy (SArray<T> const& orig) {
            assert(size()==orig.size());
            for (std::size_t idx = 0; idx<size(); ++idx) {
                storage[idx] = orig.storage[idx];
            }
        }
    
    private:
        T*     storage;       // storage of the elements
        std::size_t storage_size;  // number of elements
    };
    
    • 数值运算符实现如下
    // exprtmpl/sarrayops1.hpp
    
    // addition of two SArrays
    template<typename T>
    SArray<T> operator+ (SArray<T> const& a, SArray<T> const& b)
    {
        assert(a.size()==b.size());
        SArray<T> result(a.size());
        for (std::size_t k = 0; k<a.size(); ++k) {
            result[k] = a[k]+b[k];
        }
        return result;
    }
    
    // multiplication of two SArrays
    template<typename T>
    SArray<T> operator* (SArray<T> const& a, SArray<T> const& b)
    {
        assert(a.size()==b.size());
        SArray<T> result(a.size());
        for (std::size_t k = 0; k<a.size(); ++k) {
            result[k] = a[k]*b[k];
        }
        return result;
    }
    
    // multiplication of scalar and SArray
    template<typename T>
    SArray<T> operator* (T const& s, SArray<T> const& a)
    {
        SArray<T> result(a.size());
        for (std::size_t k = 0; k<a.size(); ++k) {
            result[k] = s*a[k];
        }
        return result;
    }
    
    • 上面这些运算符足够完成例子中的表达式计算了
    // exprtmpl/sarray1.cpp
    
    #include "sarray1.hpp"
    #include "sarrayops1.hpp"
    
    int main()
    {
        SArray<double> x(1000), y(1000);
        ...
        x = 1.2*x + x*y;
    }
    
    • 但上面的实现非常低效,有两方面原因
      • 每个运算符操作(除了赋值运算符)至少要生成一个临时数组,例子中编译器不执行任何附加的临时拷贝,也至少会生成3个大小为1000的临时数组
      • 运算符程序每次使用都要求对实参和结果数组进行额外遍历,例子中只生成一个SArray对象,需要读取6000次double值,写入4000次double值
    tmp1 = 1.2*x; // 循环1000次元素操作,再加上创建和删除tmp1
    tmp2 = x*y // 循环1000次元素操作,再加上创建和删除tmp2
    tmp3 = tmp1+tmp2; // 循环1000次读写操作,再加上创建和删除tmp3
    x = tmp3; // 1000次读操作和写操作
    
    • 对于元素很多的数组,没有足够的内存容纳这些临时对象,每个数值数组程序库的实现都会面临这个问题,因此通常使用computed assignments(如+=、*=)来代替前面的赋值运算符,这样不需要创建任何临时对象
    // exprtmpl/sarrayops2.hpp
    
    // additive assignment of SArray
    template<class T>
    SArray<T>& SArray<T>::operator+= (SArray<T> const& b)
    {
        assert(size()==orig.size());
        for (std::size_t k = 0; k<size(); ++k) {
            (*this)[k] += b[k];
        }
        return *this;
    }
    
    // multiplicative assignment of SArray
    template<class T>
    SArray<T>& SArray<T>::operator*= (SArray<T> const& b)
    {
        assert(size()==orig.size());
        for (std::size_t k = 0; k<size(); ++k) {
            (*this)[k] *= b[k];
        }
        return *this;
    }
    
    // multiplicative assignment of scalar
    template<class T>
    SArray<T>& SArray<T>::operator*= (T const& s)
    {
        for (std::size_t k = 0; k<size(); ++k) {
            (*this)[k] *= s;
        }
        return *this;
    }
    
    • 用这些运算符改写前面的例子
    // exprtmpl/sarray2.cpp
    
    #include "sarray2.hpp"
    #include "sarrayops1.hpp"
    #include "sarrayops2.hpp"
    
    int main()
    {
        SArray<double> x(1000), y(1000);
        //...
        // process x = 1.2*x + x*y
        SArray<double> tmp(x);
        tmp *= y;
        x *= 1.2;
        x += tmp;
    }
    
    • 显然使用computed assignments有下列明显缺点
      • 符号变得不雅观
      • 仍要创建一个非必要的局部变量tmp
      • 循环被分割成多个操作,这里是3个,意味着要对double进行6000次读和4000次写操作
    • 实际上所期望的操作是针对数组每个下标,只对表达式进行一次理想循环
    int main()
    {
        SArray<double> x(1000), y(1000);
        ...
        for (int idx = 0; idx<x.size(); ++idx) {
            x[idx] = 1.2*x[idx] + x[idx]*y[idx];
        }
    }
    
    • 现在就不需要局部数组了,每次迭代只需要进行两次读(x[idx] and y[idx])和一次写(x[k]),在手工循环中总共只需要2000次读和1000次写操作。最后我们希望既能得到高性能,又不需要手工循环,需要用到下面的技术,让代码更优雅且减少错误产生

    在模板实参中编码表达式

    • 对前面的问题,一个很好的解决方法是,直到看到整个表达式(上例中为调用赋值运算符时)才对表达式各部分求值,因此求值前必须记录每个对象和该对象上的每个操作,这些操作在编译期已经确定,因此可以用模板实参编码
    1.2*x + x*y;
    
    • 上面的表达式中,1.2*x不是一个新的数组,而是一个用于表示x的每个值都乘1.2的对象,x*y表示x的每个元素乘y相应元素,最后需要结果数组值时才进行计算,即只是先存储用于后来求值的一种表示,并没有真正进行计算。把上面表达式转为一个具有如下类型的对象
    A_Add<A_Mult<A_Scalar<double>,Array<double>>,
          A_Mult<Array<double>,Array<double>>>
    
    • 对这个表达式,存在一个前序语法树的表示方法
    Tree representation of expression 1.2*x+x*y

    表达式模板的操作数

    • 为了完整表示整个表达式,一方面在每个A_Add和A_Mult对象中,必须存储指向实参的引用,另一方面在A_Scalar对象中需要记录这个表示放大倍数的值或引用,下面是对这些操作数的可行定义
    // exprtmpl/exprops1.hpp
    
    #include <cstddef>
    #include <cassert>
    
    // include helper class traits template to select whether to refer to an
    // ``expression template node'' either ``by value'' or ``by reference.''
    #include "exprops1a.hpp"
    
    // class for objects that represent the addition of two operands
    template <typename T, typename OP1, typename OP2>
    class A_Add {
    private:
        typename A_Traits<OP1>::ExprRef op1;    // first operand
        typename A_Traits<OP2>::ExprRef op2;    // second operand
    
    public: 
        // constructor initializes references to operands
        A_Add (OP1 const& a, OP2 const& b)
         : op1(a), op2(b) {
        }
    
        // compute sum when value requested
        T operator[] (size_t idx) const {
            return op1[idx] + op2[idx];
        }
    
        // size is maximum size
        std::size_t size() const {
            assert (op1.size()==0 || op2.size()==0
                    || op1.size()==op2.size());
            return op1.size()!=0 ? op1.size() : op2.size();
        }
    };
    
    // class for objects that represent the multiplication of two operands
    template <typename T, typename OP1, typename OP2>
    class A_Mult {
    private:
        typename A_Traits<OP1>::ExprRef op1;    // first operand
        typename A_Traits<OP2>::ExprRef op2;    // second operand
    
    public:
        // constructor initializes references to operands
        A_Mult (OP1 const& a, OP2 const& b)
         : op1(a), op2(b) {
        }
    
        // compute product when value requested
        T operator[] (size_t idx) const {
            return op1[idx] * op2[idx];
        }
    
        // size is maximum size
        std::size_t size() const {
            assert (op1.size()==0 || op2.size()==0
                    || op1.size()==op2.size());
            return op1.size()!=0 ? op1.size() : op2.size();
        }
    };
    
    • 由代码可见,增加了下标运算符和查询容量大小的操作,从而可以根据该对象子节点(见上图)的相应操作来计算该节点的大小和每个元素的值。对于只涉及到数组的操作,结果数组的大小是其中某个操作数的大小,然而对于同时涉及到数组和scalar的操作,结果数组的大小就是操作数数组的大小。为了区分数组操作数和scalar操作数,假定scalar大小为0,模板A_Scalar定义如下
    // exprtmpl/exprscalar.hpp
    
    // class for objects that represent scalars
    template <typename T>
    class A_Scalar {
    private:
        T const& s;  // value of the scalar
    
    public:
        // constructor initializes value
        constexpr A_Scalar (T const& v)
         : s(v) {
        }
    
        // for index operations the scalar is the value of each element
        constexpr T const& operator[] (std::size_t) const {
            return s;
        }
    
        // scalars have zero as size
        constexpr std::size_t size() const {
            return 0;
        };
    };
    
    • 由上面代码可以看出,A_Scalar模板也提供了一个索引运算符。在表达式内部,A_Scalar表示的是一个每个索引都相应相同scalar值的数组。运算符类还使用了一个辅助类A_Traits来定义操作数成员
    typename A_Traits<OP1>::ExprRef op1; // first operand
    typename A_Traits<OP2>::ExprRef op2; // second operand
    
    • 这种做法是必要的,通常可以把这些操作数声明为引用类型,因为大多数局部节点在顶层表达式绑定,生命期能延续到完整表达式的求值。唯一例外的是A_Scalar节点,它在运算符函数内部绑定,不能一直存在到完整表达式的求值,因此需要传值拷贝而非传引用,即需要以下性质的成员
    // 通常情况下是const&
    OP1 const& op1; // refer to first operand by reference
    OP2 const& op2; // refer to second operand by reference
    // 但对scalar值是普通值
    OP1 op1; // refer to first operand by value
    OP2 op2; // refer to second operand by value
    
    • 这就要用到trait class,它定义了一个针对大多数const引用的基本模板,同时也定义了针对scalar的特化
    // exprtmpl/exprops1a.hpp
    
    /* helper traits class to select how to refer to an ''expression template node''
     * - in general: by reference
     * - for scalars: by value
     */
    
    template <typename T> class A_Scalar;
    
    // primary template
    template <typename T>
    class A_Traits {
    public:
        using ExprRef = T const&;     // type to refer to is constant reference
    };
    
    // partial specialization for scalars
    template <typename T>
    class A_Traits<A_Scalar<T>> {
    public:
        using ExprRef = A_Scalar<T>;  // type to refer to is ordinary value
    };
    
    • 另外,如果A_Scalar对象引用的是顶层定义的scalar,对这些scalar也可以用引用类型

    Array类型

    • 既然能使用轻量级的表达式模板对表达式编码,下面创建一个Array类型,它既能针对占用实际内存的数组,同时也适用于表达式模板。接口设计上应该与占用存储空间的真实数组相似,同时要与基于数组的表达式具有相同表示,Array模板声明如下
    template<typename T, typename Rep = SArray<T>>
    class Array;
    
    • 上面代码中,Rep类型要么是SArray(前提是Array必须是一个占用实际存储空间的数组),要么是一个用于编码表达式的嵌套template-id,如A_Add和A_Mult。我们将用同一种方式处理这两种途径产生的Array实例化体,以简化后期编码,如果用A_Mult等类型替换Rep,某些成员不能被实例化,但实际中Array模板的定义不需要声明用于区分上面两种情况(即SArray和template-id)的特化,下面是一个定义,下面用到了decltype(auto),处理数组下标很方便
    // exprtmpl/exprarray.hpp
    
    #include <cstddef>
    #include <cassert>
    #include "sarray1.hpp"
    
    template <typename T, typename Rep = SArray<T>>
    class Array {
    private:
        Rep expr_rep;   // (access to) the data of the array
    
    public:
        // create array with initial size
        explicit Array (std::size_t s)
         : expr_rep(s) {
        }
    
        // create array from possible representation
        Array (Rep const& rb)
         : expr_rep(rb) {
        }
    
        // assignment operator for same type
        Array& operator= (Array const& b) { 
            assert(size()==b.size());
            for (std::size_t idx = 0; idx<b.size(); ++idx) {
                expr_rep[idx] = b[idx];
            }
            return *this;
        }
    
        // assignment operator for arrays of different type
        template<typename T2, typename Rep2>
        Array& operator= (Array<T2, Rep2> const& b) { 
            assert(size()==b.size());
            for (std::size_t idx = 0; idx<b.size(); ++idx) {
                expr_rep[idx] = b[idx];
            }
            return *this;
        }
    
        // size is size of represented data
        std::size_t size() const {
            return expr_rep.size();
        }
    
        // index operator for constants and variables
        decltype(auto) operator[] (std::size_t idx) const {
            assert(idx<size());
            return expr_rep[idx];
        }
        T& operator[] (std::size_t idx) {
            assert(idx<size());
            return expr_rep[idx];
        }
    
        // return what the array currently represents
        Rep const& rep() const {
            return expr_rep;
        }
    
        Rep& rep() {
            return expr_rep;
        }
    };
    
    • 这里的许多操作都只是简单委托给所含的Rep对象,但拷贝另一个数组时必须考虑,另一个数组是否基于表达式模板,因此需要根据Rep的表示对拷贝运算符进行参数化,即声明针对两种不同情况的赋值运算符

    运算符

    • 目前只是实现了用于代表运算符的、针对数值Array模板的运算符操作(如A_Add),但没实现运算符本身(如+)。正如前面所说,这些运算符只是用于代表表达式模板对象,实际上并不对结果数组求值。显然对每个普通的二元运算符,必须实现三个版本,即array-array,array-scalar, scalar-array,如为了计算前面的表达式初始值需要用到下面的运算符
    // exprtmpl/exprops2.hpp
    
    // addition of two Arrays
    template <typename T, typename R1, typename R2>
    Array<T,A_Add<T,R1,R2>>
    operator+ (Array<T,R1> const& a, Array<T,R2> const& b) {
        return Array<T,A_Add<T,R1,R2>>
               (A_Add<T,R1,R2>(a.rep(),b.rep()));
    }
    
    // multiplication of two Arrays
    template <typename T, typename R1, typename R2>
    Array<T, A_Mult<T,R1,R2>>
    operator* (Array<T,R1> const& a, Array<T,R2> const& b) {
        return Array<T,A_Mult<T,R1,R2>>
               (A_Mult<T,R1,R2>(a.rep(), b.rep()));
    }
    
    // multiplication of scalar and Array
    template <typename T, typename R2>
    Array<T, A_Mult<T,A_Scalar<T>,R2>>
    operator* (T const& s, Array<T,R2> const& b) {
        return Array<T,A_Mult<T,A_Scalar<T>,R2>>
               (A_Mult<T,A_Scalar<T>,R2>(A_Scalar<T>(s), b.rep()));
    }
    
    • 这些运算符声明看起来复杂,实际上函数做的工作不多,如对两个数组的加法运算符,首先生成一个用于A_Add<>对象用于表示运算符和操作数
    A_Add<T,R1,R2>(a.rep(),b.rep())
    
    • 并把这个对象封装到一个数组中,从而可以借助数组来操作这个运算结果,实际上其他对象也是这样处理的
    return Array<T,A_Add<T,R1,R2>> (... );
    
    • 对scalar乘法使用了A_Scalar模板创建A_Mult对象
    A_Mult<T,A_Scalar<T>,R2>(A_Scalar<T>(s), b.rep())
    
    • 并对它进行了封装
    return Array<T,A_Mult<T,A_Scalar<T>,R2>> (... );
    
    • 其他二元运算符实现类似,也可以用宏来声明这些运算符,从而只需要使用较少的代码

    回顾

    • 对前面的例子,现在进行一个自顶向下的回顾。下面是要分析的代码
    int main()
    {
        Array<double> x(1000), y(1000);
        ...
        x = 1.2*x + x*y;
    }
    
    • 由于x和y的定义中省略了Rep实参,所以该参数使用默认值SArray<double>,因此x和y是占用真实内存的数组,也就是说它们不只是用于记录操作。当解析表达式1.2*x + x*y时,编译器首先应用最左边的*,它是一个scalar-array运算符,于是重载解析规则选择operator*的scalar-array形式
    template<typename T, typename R2>
    Array<T, A_Mult<T,A_Scalar<T>,R2>>
    operator* (T const& s, Array<T,R2> const& b) {
        return Array<T,A_Mult<T,A_Scalar<T>,R2>>
               (A_Mult<T,A_Scalar<T>,R2>(A_Scalar<T>(s), b.rep()));
    }
    
    • 其中操作数类型是double和Array<double, SArray<double>>,因此实际的结果类型是Array<double, A_Mult<double, A_Scalar<double>, SArray<double>>>,而结果值是一个构造自double值1.2的A_Scalar<double>对象和一个表示对象x的SArrayr<double>对象
    • 接着对第二个乘法求值,x*y是一个array-array操作,使用相应的operator*
    template <typename T, typename R1, typename R2>
    Array<T, A_Mult<T,R1,R2>>
    operator* (Array<T,R1> const& a, Array<T,R2> const& b) {
        return Array<T,A_Mult<T,R1,R2>>
            (A_Mult<T,R1,R2>(a.rep(), b.rep()));
    }
    
    • 而两个操作数的类型都是Array<double, SArray<double>>,因此结果类型为Array<double, A_Mult<double, SArray<double>, SArray<double>>>,这次A_Mult所封装的两个参数对象都引用了一个SArray<double>,代表一个用于表示x对象,另一个用于表示y对象
    • 最后对+运算符求值,依然是array-array操作 ,操作数类型是上面推断的类型,因此调用针对array-array的oprator+
    template <typename T, typename R1, typename R2>
    Array<T,A_Add<T,R1,R2>>
    operator+ (Array<T,R1> const& a, Array<T,R2> const& b) {
        return Array<T,A_Add<T,R1,R2> >
            (A_Add<T,R1,R2>(a.rep(),b.rep()));
    }
    
    • 其中用double替换T,R1替换为A_Mult<double, A_Scalar<double>, SArray<double> >,R2替换为A_Mult<double, SArray<double>, SArray<double> >,最终赋值运算符右边的表达式类型为
    Array<double,
        A_Add<double,
              A_Mult<double, A_Scalar<double>, SArray<double>>,
              A_Mult<double, SArray<double>, SArray<double>>>>
    
    • 这个类型与Array模板的赋值运算符模板进行匹配
    template <typename T, typename Rep = SArray<T> >
    class Array {
    public:
    ...
        // assignment operator for arrays of different type
        template<typename T2, typename Rep2>
        Array& operator= (Array<T2, Rep2> const& b) {
            assert(size()==b.size());
            for (std::size_t idx = 0; idx<b.size(); ++idx) {
                expr_rep[idx] = b[idx];
            }
            return *this;
        }
        ...
    };
    
    • 其中赋值运算符将会使用右边Array(即b)的下标运算符来计算目标数组x的每个元素,右边Array的实际类型为
    A_Add<double,
          A_Mult<double, A_Scalar<double>, SArray<double>>,
          A_Mult<double, SArray<double>, SArray<double>>>>
    
    • 如果仔细跟踪这个下标操作,对一个给定的下标x将得到
    (1.2*x[idx]) + (x[idx]*y[idx])
    
    • 这正是所期望计算的表达式

    表达式模板赋值

    • 对于一个Rep实参基于A_Mult或A_Add表达式模板的数组,是不能为该数组实例化写操作的(即编写a+b=c的式子毫无意义),但可以编写其他的表达式模板从而能对这些表达式模板的结果赋值,如以具有整数值数组为下标的索引操作通常会涉及到子集的选择,即x[y] = 2*x[y]等价于
    for (std::size_t idx = 0; idx<y.size(); ++idx) {
        x[y[idx]] = 2*x[y[idx]];
    }
    
    • 为了使上面这种写法可行,必须令这种基于表达式模板的数组的行为能像一个左值,即可写。而且类似于这样的表达式模板的组件和A_Mult等类似,唯一区别在于它提供了下标运算符的const版本和non-const版本,并返回一个左值引用,用decltype(auto)处理数组下标很方便
    // exprtmpl/exprops3.hpp
    
    template<typename T, typename A1, typename A2>
    class A_Subscript {
    public:
        // constructor initializes references to operands
        A_Subscript (A1 const& a, A2 const& b)
         : a1(a), a2(b) {
        }
    
        // process subscription when value requested
        decltype(auto) operator[] (size_t idx) const {
            return a1[a2[idx]];
        }
        T& operator[] (size_t idx) {
            return a1[a2[idx]];
        }
    
        // size is size of inner array
        std::size_t size() const {
            return a2.size();
        }
    private:
        A1 const& a1;    // reference to first operand
        A2 const& a2;    // reference to second operand
    };
    
    • 针对这种运用子集语义的、扩展的下标运算符,要为Array模板定义额外的下标运算符,其中一个定义如下(还需要一个针对const的对应版本)
    // exprtmpl/exprops4.hpp
    
    template<typename T, typename R>
      template<typename T2, typename R2> inline
    Array<T, A_Subscript<T, R, R2>>
    Array<T, R>::operator[](Array<T2, R2> const& b) {
        return Array<T, A_Subscript<T, R, R2>>
               (A_Subscript<T, R, R2>(*this, b));
    } 
    

    表达式模板的性能与约束

    • 表达式模板可以提高数组操作性能,跟踪其行为可以发现许多很小的内联函数互相调用,在调用堆栈还分配了许多小的表达式模板对象,因此编译器必须执行完整的内联小对象和去除小对象操作,来产生性能上和手写循环媲美的代码
    • 表达式模板并没有解决所有涉及数组数值操作的问题,如对x = A*x这种矩阵-vector乘法,x是一个大小为n的vector,A是一个n*n矩阵。问题在于临时变量的使用不可避免,因为最终结果的每个元素都依赖于最初x的每个元素,而表达式模板会在一次计算上马上更新x的首个元素,计算下一个元素时用到这个已更新的元素就改变了原来的数组
    • 但针对x = A*y,如果x和y不互为别名,就不需要一个临时对象,这表示必须在运行期知道操作数是否为别名关系,反过来又说明必须生成一个用于表示表达式树的运行期结构,而不是在表达式模板的类型中编码这棵树

    相关文章

      网友评论

        本文标题:【C++ Templates(25)】表达式模板

        本文链接:https://www.haomeiwen.com/subject/xvhzuftx.html