美文网首页
C++ 编译期有理数库 ratio

C++ 编译期有理数库 ratio

作者: 奇点创客 | 来源:发表于2021-08-06 16:59 被阅读0次

std::ratio 简介

std::ratio 的实现

// ratio.h
#pragma once
#ifndef _MLIBCXX_RATIO
#define _MLIBCXX_RATIO 1

#include "type_traits.h"  // for integral_constant, bool_constant, void_t
#include "cstdint.h"      // for intmax_t, INTMAX_MAX

namespace mystd {

// Sign of a signed integer as a compile-time constant:
// -1 for negative values, +1 for everything else (zero included).
template<intmax_t Value>
struct _sign : integral_constant<intmax_t, (Value >= 0 ? 1 : -1)> {};

// Shorthand for _sign<Value>::value.
template<intmax_t Value>
inline constexpr intmax_t _sign_v = _sign<Value>::value;

// Absolute value of a signed integer as a compile-time constant,
// computed by multiplying the value by its own sign (-1 or +1).
template<intmax_t Value>
struct _abs : integral_constant<intmax_t, _sign<Value>::value * Value> {};

// Shorthand for _abs<Value>::value.
template<intmax_t Value>
inline constexpr intmax_t _abs_v = _abs<Value>::value;

// Greatest common divisor of two integers, evaluated at compile time with
// the Euclidean algorithm. The result is always non-negative.
//
// Recursive step: gcd(Lhs, Rhs) == gcd(Rhs, Lhs % Rhs).
template<intmax_t Lhs, intmax_t Rhs>
struct _gcd : _gcd<Rhs, Lhs % Rhs> {};

// Terminal case: the first operand is zero, so the result is |Rhs|.
template<intmax_t Rhs>
struct _gcd<0, Rhs> : integral_constant<intmax_t, _abs<Rhs>::value> {};

// Terminal case: the second operand is zero, so the result is |Lhs|.
// NOTE: _gcd<0, 0> matches both terminal specializations equally well and
// is therefore ambiguous; no such instantiation is expected from a valid ratio.
template<intmax_t Lhs>
struct _gcd<Lhs, 0> : integral_constant<intmax_t, _abs<Lhs>::value> {};

// Shorthand for _gcd<Lhs, Rhs>::value.
template<intmax_t Lhs, intmax_t Rhs>
inline constexpr intmax_t _gcd_v = _gcd<Lhs, Rhs>::value;

// 安全的乘法(防止溢出)
template<intmax_t Ax, intmax_t Bx, bool sfinae = false,
         bool Good = (_abs_v<Ax> <= (INTMAX_MAX / (Bx == 0 ? 1 : _abs_v<Bx>)))>
struct _safe_mult : integral_constant<intmax_t, Ax * Bx> {};

template<intmax_t Ax, intmax_t Bx, bool sfinae>
struct _safe_mult<Ax, Bx, sfinae, false> 
{ static_assert(sfinae, "integer arithmetic overflow"); };

template<intmax_t Ax, intmax_t Bx>
inline constexpr intmax_t _safe_mult_v = _safe_mult<Ax, Bx>::value;

// 安全的加法(防止溢出) 
template<intmax_t Ax, intmax_t Bx, bool Good, bool AlsoGood>
struct _safe_addx : integral_constant<intmax_t, Ax + Bx> {};

template<typename Tp>
inline constexpr bool always_false = false;

template<intmax_t Ax, intmax_t Bx> 
struct _safe_addx<Ax, Bx, false, false>
{ static_assert(always_false<_safe_addx>, "integer arithmetic overflow"); };

template<intmax_t Ax, intmax_t Bx>
struct _safe_add : _safe_addx<Ax, Bx, _sign_v<Ax> != _sign_v<Bx>,
                            (_abs_v<Ax> <= INTMAX_MAX - _abs_v<Bx>)
                            >::type {};

template<intmax_t Ax, intmax_t Bx>
inline constexpr intmax_t safe_add_v = _safe_add<Ax, Bx>::value;

// STRUCT TEMPLATE ratio (C++11)
template<intmax_t Nx, intmax_t Dx = 1>
struct ratio {
    static_assert(Dx != 0, "denominator cannot be zero");
    static_assert(-INTMAX_MAX <= Nx, "numerator too negative");
    static_assert(-INTMAX_MAX <= Dx, "denominator too negative");

    static constexpr intmax_t num = 
        _sign_v<Nx> * _sign_v<Dx> * _abs_v<Nx> / _gcd_v<Nx, Dx>;

    static constexpr intmax_t den = _abs_v<Dx> / _gcd_v<Nx, Dx>;

    using type = ratio<num, den>;
};

// VARIABLE TEMPLATE _is_ratio_v
// False for every type by default ...
template<typename Candidate>
inline constexpr bool _is_ratio_v = false;

// ... and true only for specializations of ratio.
template<intmax_t Num, intmax_t Den>
inline constexpr bool _is_ratio_v<ratio<Num, Den>> = true;

// CONCEPT ratio_c
// Satisfied exactly by the types for which _is_ratio_v is true.
template<typename Candidate>
concept ratio_c = _is_ratio_v<Candidate>;

// ALIAS TEMPLATE ratio_add
template<ratio_c Rx1, ratio_c Rx2> 
struct _ratio_add {
    static constexpr intmax_t Nx1 = Rx1::num;
    static constexpr intmax_t Dx1 = Rx1::den;
    static constexpr intmax_t Nx2 = Rx2::num;
    static constexpr intmax_t Dx2 = Rx2::den;

    static constexpr intmax_t _Gx = _gcd_v<Dx1, Dx2>;
    static constexpr intmax_t _Nx = safe_add_v<_safe_mult_v<Nx1, Dx2 / _Gx>,
                                               _safe_mult_v<Nx2, Dx1 / _Gx>>;
    static constexpr intmax_t _Dx = _safe_mult_v<Dx1, Dx2 / _Gx>;

    using type = typename ratio<_Nx, _Dx>::type;
};

template<ratio_c Rx1, ratio_c Rx2>
using ratio_add = typename _ratio_add<Rx1, Rx2>::type;

// ALIAS TEMPLATE ratio_subtract
template<ratio_c Rx1, ratio_c Rx2>
struct _ratio_subtract {
    static constexpr intmax_t Nx2 = Rx2::num;
    static constexpr intmax_t Dx2 = Rx2::den;

    using type = ratio_add<Rx1, ratio<-Nx2, Dx2>>;
};

template<ratio_c Rx1, ratio_c Rx2>
using ratio_subtract = typename _ratio_subtract<Rx1, Rx2>::type;

// ALIAS TEMPLATE ratio_multiply
template<ratio_c Rx1, ratio_c Rx2> 
struct _ratio_multiply {
    static constexpr intmax_t Nx1 = Rx1::num;
    static constexpr intmax_t Dx1 = Rx1::den;
    static constexpr intmax_t Nx2 = Rx2::num;
    static constexpr intmax_t Dx2 = Rx2::den;

    static constexpr intmax_t _Gx = _gcd_v<Nx1, Dx2>;
    static constexpr intmax_t _Gy = _gcd_v<Nx2, Dx1>;

    using _Num = _safe_mult<Nx1 / _Gx, Nx2 / _Gy, true>;
    using _Den = _safe_mult<Dx1 / _Gy, Dx2 / _Gx, true>;  
};

template<ratio_c Rx1, ratio_c Rx2, bool sfinae, typename = void>
struct _ratio_multiply_sfinae 
{ static_assert(sfinae, "integer arithmetic overflow"); };

template<ratio_c Rx1, ratio_c Rx2, bool sfinae>
struct _ratio_multiply_sfinae<Rx1, Rx2, sfinae,
    void_t<typename _ratio_multiply<Rx1, Rx2>::_Num::type,
           typename _ratio_multiply<Rx1, Rx2>::_Den::type>>
{
    using type = ratio<_ratio_multiply<Rx1, Rx2>::_Num::value,
                       _ratio_multiply<Rx1, Rx2>::_Den::value>;
};

template<ratio_c Rx1, ratio_c Rx2>
using ratio_multiply = typename _ratio_multiply_sfinae<Rx1, Rx2, false>::type;     

// ALIAS TEMPLATE ratio_divide
template<ratio_c Rx1, ratio_c Rx2>
struct _ratio_divide {
    static constexpr intmax_t Nx2 = Rx2::num;
    static constexpr intmax_t Dx2 = Rx2::den;
    using _Rx2_inverse = ratio<Dx2, Nx2>;
};

template<ratio_c Rx1, ratio_c Rx2, bool sfinae = true>
using _ratio_divide_sfinae = typename 
    _ratio_multiply_sfinae<Rx1, 
                           typename _ratio_divide<Rx1, Rx2>::_Rx2_inverse, 
                           sfinae>::type;

template<ratio_c Rx1, ratio_c Rx2>
using ratio_divide = _ratio_divide_sfinae<Rx1, Rx2, false>;

// STRUCT TEMPLATE ratio_equal
template<ratio_c Rx1, ratio_c Rx2>
struct ratio_equal : 
    bool_constant<Rx1::num == Rx2::num && Rx1::den == Rx2::den> {};

template<ratio_c Rx1, ratio_c Rx2>
inline constexpr bool ratio_equal_v = ratio_equal<Rx1, Rx2>::value;

// STRUCT TEMPLATE ratio_not_equal
template<ratio_c Rx1, ratio_c Rx2>
struct ratio_not_equal : bool_constant<!ratio_equal_v<Rx1, Rx2>> {};

template<ratio_c Rx1, ratio_c Rx2>
inline constexpr bool ratio_not_equal_v = ratio_not_equal<Rx1, Rx2>::value;

// STRUCT TEMPLATE ratio_less
// Unsigned 128-bit value stored as two 64-bit halves, with a less-than
// comparison that orders by the upper half first and the lower half second.
struct _big_uint128 {
    uint64_t _upper;  // most significant 64 bits
    uint64_t _lower;  // least significant 64 bits

    constexpr bool operator<(const _big_uint128 _rhs) const noexcept {
        // The lower halves only matter when the upper halves are equal.
        return _upper == _rhs._upper ? _lower < _rhs._lower
                                     : _upper < _rhs._upper;
    }
};
// Multiplies two 64-bit unsigned integers into a full 128-bit product
// (Knuth's algorithm M), working on 32-bit halves so that no intermediate
// 64-bit product or sum can wrap around.
constexpr _big_uint128 _big_multiply(const uint64_t _left_factor,
                                     const uint64_t _right_factor) noexcept
{
    constexpr uint64_t low_mask = 0xFFFF'FFFFULL;

    // Split both factors into 32-bit halves.
    const uint64_t a_lo = _left_factor & low_mask;
    const uint64_t a_hi = _left_factor >> 32;
    const uint64_t b_lo = _right_factor & low_mask;
    const uint64_t b_hi = _right_factor >> 32;

    // low * low: supplies bits 0..31 of the result plus a carry.
    const uint64_t lo_lo = a_lo * b_lo;

    // First cross product, absorbing the carry out of lo_lo.
    const uint64_t lo_hi = a_lo * b_hi + (lo_lo >> 32);

    // Second cross product, absorbing the low half of lo_hi.
    const uint64_t hi_lo = a_hi * b_lo + (lo_hi & low_mask);

    // high * high plus the upper halves of both cross terms.
    const uint64_t upper = a_hi * b_hi + (lo_hi >> 32) + (hi_lo >> 32);
    // Low half of the second cross term becomes bits 32..63.
    const uint64_t lower = (hi_lo << 32) + (lo_lo & low_mask);

    return { upper, lower };
}

constexpr bool _ratio_less(const int64_t Nx1, const int64_t Dx1, 
                           const int64_t Nx2, const int64_t Dx2) noexcept 
{
    if (Nx1 >= 0 && Nx2 >= 0) {
        return _big_multiply(static_cast<uint64_t>(Nx1), 
                             static_cast<uint64_t>(Dx2))
             < 
               _big_multiply(static_cast<uint64_t>(Nx2), 
                             static_cast<uint64_t>(Dx1));
    }

    if (Nx1 < 0 && Nx2 < 0) {
        return _big_multiply(static_cast<uint64_t>(-Nx2), 
                             static_cast<uint64_t>(Dx1))
             < 
               _big_multiply(static_cast<uint64_t>(-Nx1), 
                             static_cast<uint64_t>(Dx2));
    }

    return Nx1 < Nx2;
}

template <ratio_c Rx1, ratio_c Rx2>
struct ratio_less : 
    bool_constant<_ratio_less(Rx1::num, Rx1::den, Rx2::num, Rx2::den)> {};

template <ratio_c Rx1, ratio_c Rx2>
inline constexpr bool ratio_less_v = ratio_less<Rx1, Rx2>::value;

// STRUCT TEMPLATE ratio_less_equal
template <ratio_c Rx1, ratio_c Rx2>
struct ratio_less_equal : bool_constant<!ratio_less_v<Rx2, Rx1>> {};

template <ratio_c Rx1, ratio_c Rx2>
inline constexpr bool ratio_less_equal_v = ratio_less_equal<Rx1, Rx2>::value;

// STRUCT TEMPLATE ratio_greater
template <ratio_c Rx1, ratio_c Rx2>
struct ratio_greater : ratio_less<Rx2, Rx1>::type {};

template <ratio_c Rx1, ratio_c Rx2>
inline constexpr bool ratio_greater_v = ratio_greater<Rx1, Rx2>::value;

// STRUCT TEMPLATE ratio_greater_equal
template <ratio_c Rx1, ratio_c Rx2>
struct ratio_greater_equal : bool_constant<!ratio_less_v<Rx1, Rx2>> {};

template <ratio_c Rx1, ratio_c Rx2>
inline constexpr bool ratio_greater_equal_v = 
    ratio_greater_equal<Rx1, Rx2>::value;

// SI TYPEDEFS
// Decimal SI prefixes expressed as ratios.
// Sub-multiples (powers of ten below one):
using atto  = ratio<1, 1'000'000'000'000'000'000>;  // 10^-18
using femto = ratio<1, 1'000'000'000'000'000>;      // 10^-15
using pico  = ratio<1, 1'000'000'000'000>;          // 10^-12
using nano  = ratio<1, 1'000'000'000>;              // 10^-9
using micro = ratio<1, 1'000'000>;                  // 10^-6
using milli = ratio<1, 1'000>;                      // 10^-3
using centi = ratio<1, 100>;                        // 10^-2
using deci  = ratio<1, 10>;                         // 10^-1
// Multiples (powers of ten above one); the denominator defaults to 1:
using deca  = ratio<10>;                            // 10^1
using hecto = ratio<100>;                           // 10^2
using kilo  = ratio<1'000>;                         // 10^3
using mega  = ratio<1'000'000>;                     // 10^6
using giga  = ratio<1'000'000'000>;                 // 10^9
using tera  = ratio<1'000'000'000'000>;             // 10^12
using peta  = ratio<1'000'000'000'000'000>;         // 10^15
using exa   = ratio<1'000'000'000'000'000'000>;     // 10^18
} // namespace mystd


#endif // _MLIBCXX_RATIO

测试 ratio

// test_ratio.cpp
#include <iostream>
#include "ratio.h"

// Exercises the mystd ratio arithmetic and comparison templates and prints
// one line per check to standard output.
int main()
{
    using std::cout;
    using namespace mystd;

    // ratio
    using two_third = ratio<2, 3>;
    using one_sixth = ratio<1, 6>;

    // Each operand is printed with its own numerator and denominator, so the
    // line reads "2/3 + 1/6 = ".
    cout << two_third::num << "/" << two_third::den << " + "
         << one_sixth::num << "/" << one_sixth::den << " = ";

    // ratio_add
    using sum = ratio_add<two_third, one_sixth>;
    cout << sum::num << "/" << sum::den << '\n';

    // ratio_subtract
    using diff = ratio_subtract<two_third, one_sixth>;
    cout << "2/3 - 1/6 = " << diff::num << '/' << diff::den << '\n';

    // ratio_multiply
    using product = ratio_multiply<two_third, one_sixth>;
    cout << "2/3 * 1/6 = " << product::num << '/' << product::den << '\n';

    // ratio_divide
    using quotient = ratio_divide<two_third, one_sixth>;
    cout << "2/3 / 1/6 = " << quotient::num << '/' << quotient::den << '\n';

    // ratio_equal_v
    if (ratio_equal_v<ratio<2, 3>, ratio<4, 6>>)
        cout << "2/3 == 4/6\n";
    else
        cout << "2/3 != 4/6\n";

    // ratio_not_equal_v
    if (ratio_not_equal_v<ratio<2, 3>, ratio<1, 3>>)
        cout << "2/3 != 1/3\n";
    else
        cout << "2/3 == 1/3\n";

    // ratio_less_v
    if (ratio_less_v<ratio<23, 37>, ratio<57, 90>>)
        cout << "23/37 < 57/90\n";

    // ratio_less_equal
    static_assert(ratio_less_equal<ratio<1, 2>, ratio<3, 4>>::value, "1/2 <= 3/4");

    if (ratio_less_equal<ratio<10, 11>, ratio<11, 12>>::value)
        cout << "10/11 <= 11/12" "\n";

    // ratio_less_equal_v (since C++17)
    static_assert(ratio_less_equal_v<ratio<10, 11>, ratio<11, 12>>);

    // The tested operands match the printed message (11/12 vs 12/13).
    if constexpr (ratio_less_equal_v<ratio<11, 12>, ratio<12, 13>>)
        cout << "11/12 <= 12/13" "\n";

    // ratio_greater
    static_assert(ratio_greater<ratio<3, 4>, ratio<1, 2>>::value, "3/4 > 1/2");

    if (ratio_greater<ratio<11, 12>, ratio<10, 11>>::value)
        cout << "11/12 > 10/11" "\n";

    // ratio_greater_v (since C++17)
    static_assert(ratio_greater_v<ratio<12, 13>, ratio<11, 12>>);

    if constexpr (ratio_greater_v<ratio<12, 13>, ratio<11, 12>>)
        cout << "12/13 > 11/12" "\n";

    // ratio_greater_equal
    static_assert(ratio_greater_equal<ratio<2, 3>, ratio<2, 3>>::value, "2/3 >= 2/3");

    if (ratio_greater_equal<ratio<2, 1>, ratio<1, 2>>::value)
        cout << "2/1 >= 1/2" "\n";

    if (ratio_greater_equal<ratio<1, 2>, ratio<1, 2>>::value)
        cout << "1/2 >= 1/2" "\n";

    // ratio_greater_equal_v (since C++17)
    static_assert(ratio_greater_equal_v<ratio<999'999, 1'000'000>,
                                        ratio<999'998, 999'999>>);

    if constexpr (ratio_greater_equal_v<ratio<999'999, 1'000'000>,
                                        ratio<999'998, 999'999>>)
        cout << "999'999/1'000'000 >= 999'998/999'999" "\n";

    if constexpr (ratio_greater_equal_v<ratio<999'999, 1'000'000>,
                                        ratio<999'999, 1'000'000>>)
        cout << "999'999/1'000'000 >= 999'999/1'000'000" "\n";

}

结果输出

相关文章

网友评论

      本文标题:C++ 编译期有理数库 ratio

      本文链接:https://www.haomeiwen.com/subject/syqivltx.html