美文网首页
区间dp入门

区间dp入门

作者: 乔治yuanbo | 来源:发表于2020-12-12 09:11 被阅读0次

    题目:
    洛谷P1040加分二叉树
    大意是给一个正整数序列,它是一棵二叉树的中序遍历结果;一棵树的加分定义为左子树加分x右子树加分+根的值,若某个子树为空,则它的加分是1。
    要求输出这棵树的最大加分,和前序遍历结果。
    样例输入:
    5
    5 7 1 2 10
    输出:
    145
    3 1 2 4 5

    首先得明白题,对于一个中序遍历序列,二叉树的形态有很多种可能,题目要求找出加分最大的那种形态。对于二叉树,最重要的是确定树根,中序序列[1,n]的树根可能有n种,分别是(1,2,3...n),枚举每一个可能的树根,递归算出左右子树最大加分,最后将加分最大的定为树根。

    一、暴力搜索

    #include <cstdio>
    #define MAXN 30
    typedef long long ll;
    using namespace std;
    
    rs表示 区间[i,j]当加分最大时,树的根
    int n, a[MAXN+1], rs[MAXN][MAXN];
    //搜索从l到r这个区间内的最大加分,并确定根节点
    ll dfs(int l, int r){
        if (r == l - 1){//这是空树
            return 1;
        } else if (l == r){//叶子
            return a[l];
        } else if (r == l + 1){//序列只有两个数,由于是中序,根必然是l
            rs[l][r] = l;
            return a[l] + a[r];
        } else {//枚举每一个数,使其成为根节点,
            ll max = -1;
            for(int root = l; root <= r; root++){
                ll ans = dfs(l, root - 1) * dfs(root + 1, r) + a[root];
                if (ans > max){
                    rs[l][r] = root;
                    max = ans;
                }
            }
            return max;
        }
    }
    void print(int l, int r){//前序遍历输出
        if (l > r){//空树,什么都不输出
            return;
        } else if (l == r){//叶子,直接输出
            printf("%d ", l);
            return;
        }
        int root = rs[l][r];
        printf("%d ", root);//先输出根
        print(l, root - 1);//输出左子树
        print(root + 1, r);//输出右子树
    }
    int main(){
        scanf("%d", &n);
        for(int i = 1; i <= n; i++){
            scanf("%d", a + i);
        }
        ll ans = dfs(1, n);
        printf("%lld\n", ans);
        print(1, n);
    
        return 0;
    }
    

    二、记忆化,暴力搜索有很多重复计算,比如样例5个数的序列,1为根,[2,5]为右子树时,需要计算[3,5];2为根时,也需要计算[3,5],很慢,加上记忆化就快了,把算过的存起来,下次直接用。

    #include <cstdio>
    #define MAXN 30
    typedef long long ll;
    using namespace std;
    
    int n, a[MAXN+1], rs[MAXN][MAXN];
    ll f[MAXN][MAXN];//记忆化数组
    ll dfs(int l, int r){
        if (r == l - 1){
            return 1;
        } else if (l == r){
            return a[l];
        }
        if (f[l][r]){
            return f[l][r];
        }
        if (r == l + 1){
            rs[l][r] = l;
            return f[l][r] = a[l] + a[r];
        } else {
            ll max = -1;
            for(int root = l; root <= r; root++){
                ll ans = dfs(l, root - 1) * dfs(root + 1, r) + a[root];
                if (ans > max){
                    rs[l][r] = root;
                    max = ans;
                }
            }
            return f[l][r] = max;
        }
    }
    void print(int l, int r){
        if (l > r){
            return;
        } else if (l == r){
            printf("%d ", l);
            return;
        }
        int root = rs[l][r];
        printf("%d ", root);
        print(l, root - 1);
        print(root + 1, r);
    }
    int main(){
        scanf("%d", &n);
        for(int i = 1; i <= n; i++){
            scanf("%d", a + i);
        }
        ll ans = dfs(1, n);
        printf("%lld\n", ans);
        print(1, n);
    
        return 0;
    }
    

    三、区间dp

    记忆化搜索非常好,但此题对于学习区间dp很有帮助,必须要掌握。
    根据记忆化搜索方法,状态和转移方程已经浮现出来,f[i][j]表示中序序列从i到j的最大加分,转移方程为:
    f[i][j] = max(f[i][k-1] * f[k+1][j] + a[k]), k属于[i,j]
    最终结果就是f[1][n]
    dp一般自底向上求解,先解决小问题,然后逐步放大,显然f[i][i]是叶子结点,它的加分是已知的,即f[i][i]=a[i],f[i][i-1]表示空树,加分是1,把求解区间长度逐渐放大不断扩大规模,即枚举区间长度,从1到n,最后算出f[1][n]。

    #include <cstdio>
    #define MAXN 30
    typedef long long ll;
    using namespace std;
    
    int n, rs[MAXN+1][MAXN+1];
    ll f[MAXN+1][MAXN+1];
    void print(int l, int r){
        if (l > r){
            return;
        } else if (l == r){
            printf("%d ", l);
            return;
        }
        int root = rs[l][r];
        printf("%d ", root);
        print(l, root - 1);
        print(root + 1, r);
    }
    int main(){
        scanf("%d", &n);
        for(int i = 1; i <= n; i++){
            scanf("%d", &f[i][i]);
            f[i][i-1] = 1;
            rs[i][i] = i;
        }
        for(int i = 1; i < n; i++){
            f[i][i+1] = f[i][i] + f[i+1][i+1];
            rs[i][i+1] = i;
        }
        for(int len = 3; len <= n; len++){
            for(int i = 1; ; i++){
                if (i + len - 1 > n) break;
                //枚举每一个根,根从i到i+len-1
                for(int j = 0; j < len; j++){
                    int r = i + j;
                    ll v = f[i][r-1] * f[r+1][i+len-1] + f[r][r];
                    if (f[i][i+len-1] < v){
                        f[i][i+len-1] = v;
                        rs[i][i+len-1] = r;
                    }
                }
            }
        }
        printf("%lld\n", f[1][n]);
        print(1, n);
    
        return 0;
    }
    

    相关文章

      网友评论

          本文标题:区间dp入门

          本文链接:https://www.haomeiwen.com/subject/zsnrgktx.html