五大常用算法二(贪心,分治)

作者: fredal | 来源:发表于2016-02-24 00:03 被阅读3875次

    贪心算法

    贪心算法总是作出在当前看来最好的选择。也就是说贪心算法并不从整体最优考虑,它所作出的选择只是在某种意义上的局部最优选择。当然,希望贪心算法得到的最终结果也是整体最优的。虽然贪心算法不能对所有问题都得到整体最优解,但对许多问题它能产生整体最优解,如之前的Dijkstra算法,Prim算法,Kruskal算法。如果不要求绝对最佳答案,那么有时候我们使用简单的贪婪算法生成近似的答案.

    • 贪心与动态规划

    贪心算法和动态规划都需求最优子结构,但是贪心算法是自顶向下方式进行,就是每一步,根据策略得到一个当前最优解。传递到下一步,从而保证每一步都是选择当前最优的。最后得到结果.每一步的最优解都依赖上一部的最优解.你只考虑之前已做出的选择
    而动态规划通常自底向上解各种子问题,每一步,根据策略得到一个更小规模的问题。最后解决最小规模的问题。得到整个问题最优解.全局最优解中一定包含某个局部最优解,但不一定包含前一个局部最优解.你考虑的都是以后的子问题
    经典的还是背包问题,之前的01背包问题我们采用动态规划解决而不能用贪心.但是如果改成部分背包问题呢:
    假如有三件物品,背包可装50磅的物品,物品1重10磅,价值60元;物品2重20磅,价值100元;物品3重30磅,价值120元。你可以选择带走每个物品的全部或一部分,求如何选择可以使背包所装的价值最大?
    注意到不同点是我们可以选择带走一部分,所以使用贪心算法十分自然地想到,先算含金量啊,先把含金量最高的都带完,再带含金量其次的...很容易得到解,带走一件1,一件2,2/3件3...比较简单代码不写.

    • 活动安排问题

    设有n个活动的集合E = {1,2,…,n},其中每个活动都要求使用同一资源,如演讲会场等,而在同一时间内只有一个活动能使用这一资源。每个活i都有一个要求使用该资源的起始时间si和一个结束时间fi,且si < fi 。如果选择了活动i,则它在半开时间区间[si, fi)内占用资源。若区间[si, fi)与区间[sj, fj)不相交,则称活动i与活动j是相容的。也就是说,当si >= fj或sj >= fi时,活动i与活动j相容.怎么尽可能地安排多的相容活动呢?
    设待安排的11个活动的开始时间和结束时间按结束时间的非减序排列如下:

    6
    注意要按结束时间的早晚排列,没排好的话,你可以回去用各种方法自己排.既然贪心么就是越早结束越好,给后面留尽可能多的空间.其次"目光短浅",从排列好的里一个个选,能选一个是一个,别管后面的...
    显然,我们选择到了(1)1-4,(4)5-7,(8)8-11,(11)12-14
    感觉不靠谱么,其实对于这个活动安排问题,贪心算法总能求得的整体最优解,即它最终所确定的相容活动集合A的规模最大。这个结论可以用数学归纳法证明。
    我们还是来代码:
      package com.fredal.structure;
    import java.util.Arrays;
    public class Arrange {
       public static int[] greedyArrangement(int[] start,int[] end){
           int total=start.length;
           int endtime=end[0];//选择的所有活动的最末结束时间
           int[] arrangement=new int[total];
           arrangement[0]=1;//无脑选第一个 最早结束的那个
           int count=1;
           for(int i=0;i<total;i++){
               if(start[i]>endtime){//下一个活动开始时间晚于当前活动结束时间
                   arrangement[count++]=i+1;//活动选中
                   endtime=end[i];//更新结束时间
               }
           }
           return arrangement;
       }    
       public static void main(String[] args) {
           int[] start={1,3,0,5,3,5,6,8,8,2,12};
           int[] end={4,5,6,7,8,9,10,11,12,13,14};
           int[] arrangement=greedyArrangement(start, end);
           for(int i=0;i<arrangement.length;i++){
               if(arrangement[i]!=0)
                   System.out.println("开始时间:"+start[arrangement[i]-1]+",结束时间:"+end[arrangement[i]-1]);
           }
       }
    }
    
    
    • 哈夫曼编码

    哈夫曼编码是广泛地用于数据文件压缩的十分有效的编码方法。其压缩率通常在20%~90%之间。哈夫曼编码算法用字符在文件中出现的频率表来建立一个用0,1串表示各字符的最优表示方式。一个包含100,000个字符的文件,各字符出现频率不同,如下表所示

    7
    我们可以求得对于标准编码位数需要(45+13+12+16+9+5)*3=300,而对于变长码45×1+13×3+12×3+16×3+9×4+5×4=224,压缩了很多...
    首先要讲一讲前缀码:对每一个字符规定一个0,1串作为其代码,并要求任一字符的代码都不是其他字符代码的前缀。这种编码称为前缀码。
    我们可以用二叉树作为前缀码的数据结构:树叶表示给定字符;从树根到树叶的路径当作该字符的前缀码;代码中每一位的0或1分别作为指示某节点到左儿子或右儿子的“路标”.字符只放在树叶上,满二叉树是其基本特征,你知道放法太多了,所以关键问题变成了怎么寻找总价值最小的完全二叉树,即最优前缀码.
    对于该例样本字母表的最优树如下b,位数正是224,图a不是完全二叉树显然不符合:
    8
    那么怎么寻找的呢,就是哈夫曼编码干的事了.构造过程如下:
    假设编码字符集中每一字符c的频率是f(c)。以f为键值的优先队列Q用在贪心选择时有效地确定算法当前要合并的2棵具有最小频率的树。一旦2棵具有最小频率的树合并后,产生一棵新的树,其频率为合并的2棵树的频率之和,并将新树按顺序插入优先队列Q。经过n-1次的合并后,优先队列中只剩下一棵树,即所要求的树T
    9
    代码实现,这里需要使用完全二叉树,发现之前写的没有特别符号要求的,就这里直接实现吧.还有优先队列类,用之前实现过的MyHeap,有需要可以查看堆的实现.
      package com.fredal.structure;
    public class Huffman {
       static MyHeap<Node> heap=new MyHeap<Node>();//堆类
       static class Node implements Comparable<Node>{
           private int weight;//权值 频率
           private String value;//字符
           private Node left;
           private Node right;
           private Node parent;
           private String path;//记录路径
           private boolean isvisited;//是否遍历过
           public Node(int weight, String value) {
               super();
               this.weight = weight;
               this.value = value;
           }
           public int compareTo(Node o) {
               return weight-o.weight;
           }        
       }    
       public static Node bulidHuffman(Node[] nodes){
           for(int i=0;i<nodes.length;i++){
               heap.insert(nodes[i]);
           }
           while(heap.getCurrentSize()>1){
               Node minA = heap.deleteMin();//弹出最小的两个
               Node minB = heap.deleteMin();
               Node sumNode=new Node(minA.weight+minB.weight,minA.value+minB.value);//权值和称为其父节点
               //维护关系
               sumNode.left=minA;
               minA.path="0";//为了方便 直接把路径信息记这儿了
               sumNode.right=minB;
               minB.path="1";
               minA.parent=sumNode;
               minB.parent=sumNode;
               
               heap.insert(sumNode);//插入堆
               
           }
           
           return heap.findMin();//返回最后一个  相当于是完整的树了
       }
       
       public static void printHuffman(Node node){    
           if(node.left!=null && !node.left.isvisited){//遍历左边
               Node left = node.left;
               left.isvisited=true;
               printHuffman(left);
           }
           
           if(node.right!=null && !node.right.isvisited){//遍历右边
               Node right=node.right;
               right.isvisited=true;
               printHuffman(right);
           }
           
           if(node.left==null && node.right==null){//是叶子节点
               StringBuffer sb=new StringBuffer();
               sb.append(node.path);
               System.out.print(node.value+":");
               while(node.parent!=null){
                   node=node.parent;
                   if(node.path!=null)
                     sb.append(node.path);//访问父节点 获得路径信息
               }
               System.out.println(sb.reverse().toString());//输出
               printHuffman(node);//递归 输出下一个叶子节点的编码
           }
       }
       
       public static void main(String[] args) {
           Node[] nodes={
                   new Node(45, "a"),
                   new Node(13, "b"),
                   new Node(12, "c"),
                   new Node(16, "d"),
                   new Node(9, "e"),
                   new Node(5, "f")
           };
           Node root = bulidHuffman(nodes);
           printHuffman(root);
       }
    }
    
    
    • 近似装箱问题

    给定N 项物品,大小为 s1, s2, ..., sN,所有的大小都满足 0 < si < = 1 ;问题是要把这些物品装到最小数目的箱子中去, 已知每个箱子的容量是1个单位;下图显示的是对N项物品的最优装箱方法

    10
    这个问题有两种版本.第一种是联机装箱问题,必须将每一件物品放入一个箱子后才处理下一件物品.另外一种是脱机装箱问题,我们做任何事情都需要等到所有的输入数据被读取后才进行.
    我们先来考虑联机装箱的三种算法,第一种是下项适合算法: 当处理任一物品时,我们检查看他是否还能装进刚刚装进物品的同一个箱子中去.如果能够装进去,那么就把它装入该箱子,否则,就开辟一个新箱子.例子如下:
    11
    我们用代码模拟:
      package com.fredal.structure;
    import java.util.LinkedList;
    public class BinPacking {
       static LinkedList<Box> boxes=new LinkedList<Box>();//存储所有箱子
       static int index=1;
       
       static class Box{//箱子类
           private double remain;//剩余容量
           private LinkedList<Double> values;
           public Box(){
               remain=1;//设容量初始为1
               values=new LinkedList<Double>();//存储箱子中的物品
           }
       }
       //下项适合算法
       public static void nextfit(double[] a){
           for(int i=0;i<a.length;i++){
               if(boxes.peek()==null)
                   boxes.push(new Box());
               Box last = boxes.peek();
               if(last.remain>=a[i]){//装的下就装
                   last.values.add(a[i]);
                   last.remain-=a[i];
               }else{//装不下就开辟新箱子
                   Box nbox=new Box();
                   nbox.values.add(a[i]);
                   nbox.remain-=a[i];
                   boxes.push(nbox);
               }
           }
           
           show(boxes);
       }
       
       //输出显示
       public static void show(LinkedList<Box> boxes){
           while(!boxes.isEmpty()){
               Box box = boxes.removeLast();
               System.out.print("box"+index+++":");
               while(!box.values.isEmpty()){
                   Double value = box.values.removeFirst();
                   System.out.print(value+" ");
               }
               System.out.println();
           }
       }
       
       public static void main(String[] args) {
           double[] a={0.2,0.5,0.4,0.7,0.1,0.3,0.8};
           nextfit(a);
       }
    }
    

    下项算法的性能是线性的,但是在实践中显然是不靠谱的.不需要开辟新箱子的时候开辟了新箱子.接下来讲首次适合算法:依序扫描这些箱子把新的一项物品放入足够盛下它的第一个箱子中.

    12
    代码如下,注意show()函数有点变化的:
      package com.fredal.structure;
    import java.util.Iterator;
    import java.util.LinkedList;
    public class BinPacking {    
       static LinkedList<Box> boxes=new LinkedList<Box>();//存储所有箱子
       static int index=1;
       
       static class Box{//箱子类
           private double remain;//剩余容量
           private LinkedList<Double> values;
           public Box(){
               remain=1;//设容量初始为1
               values=new LinkedList<Double>();//存储箱子中的物品
           }
       }
       //首次适合算法
       public static void firstfit(double[] a){
           for(int i=0;i<a.length;i++){
               boolean flag=false;
               if(boxes.peek()==null)
                   boxes.add(new Box());
               Iterator<Box> it = boxes.iterator();
               while(it.hasNext()){//从头到尾遍历 能装就装
                   Box box = it.next();
                   if(box.remain>=a[i]){
                       box.values.add(a[i]);
                       box.remain-=a[i];
                       flag=true;
                       break;
                   }
               }
               if(!flag){//全部不能装就开辟新的
                   Box nbox=new Box();
                   nbox.values.add(a[i]);
                   nbox.remain-=a[i];
                   boxes.add(nbox);
               }
           }
           
           show(boxes);
       }
       
       //输出显示
       public static void show(LinkedList<Box> boxes){
           while(!boxes.isEmpty()){
               Box box = boxes.removeFirst();
               System.out.print("box"+index+++":");
               while(!box.values.isEmpty()){
                   Double value = box.values.removeFirst();
                   System.out.print(value+" ");
               }
               System.out.println();
           }
       }
       
       public static void main(String[] args) {
           double[] a={0.2,0.5,0.4,0.7,0.1,0.3,0.8};
           firstfit(a);
       }
    }
    
    

    第三个是最佳适合算法,该方法不是吧一项新物品放入所发现的第一个能够容纳它的箱子,而是放到所有箱子中能够容纳它的最满的箱子中.该算法对随记的输入表现的更好

    13
    代码如下,注意输出函数还是变了,并且采用了ArrayList存储:
      package com.fredal.structure;
    import java.util.ArrayList;
    import java.util.Iterator;
    import java.util.LinkedList;
    public class BinPacking {
       static ArrayList<Box> boxes=new ArrayList<Box>();//存储所有箱子
       static int index=1;    
       static class Box{//箱子类
           private double remain;//剩余容量
           private LinkedList<Double> values;
           public Box(){
               remain=1;//设容量初始为1
               values=new LinkedList<Double>();//存储箱子中的物品
           }
       }
       //最佳适合算法
       public static void bestfit(double[] a){
           for(int i=0;i<a.length;i++){
               double remainMin=1+1;//最小的剩余容量 用于寻找最满的箱子 初始化表示比容量大1
               int index=0;//记录箱子编号
               if(boxes.size()==0)
                   boxes.add(new Box());
               for(int j=0;j<boxes.size();j++){
                   Box box=boxes.get(j);
                   if(box.remain>=a[i] && box.remain<remainMin){//从头遍历 如果找到更满的并且能装下的就记录
                       remainMin=box.remain;
                       index=j;
                   }
               }
               if(remainMin<=1){//说明找到了可以装的
                   Box box = boxes.get(index);//装进记录好的最满的箱子
                   box.values.add(a[i]);
                   box.remain-=a[i];
               }else{//找不到可以装的就开辟新的箱子
                   Box nbox=new Box();
                   nbox.values.add(a[i]);
                   nbox.remain-=a[i];
                   boxes.add(nbox);
               }
           }
           
           show(boxes);
       }
       
       //输出显示
       public static void show(ArrayList<Box> boxes){
           for(int i=0;i<boxes.size();i++){
               Box box=boxes.get(i);
               System.out.print("box"+index+++":");
               while(!box.values.isEmpty()){
                   Double value = box.values.removeFirst();
                   System.out.print(value+" ");
               }
               System.out.println();
           }
       }
       
       public static void main(String[] args) {
           double[] a={0.2,0.5,0.4,0.7,0.1,0.3,0.8};
           bestfit(a);
       }
    }
    
    

    接下来是脱机算法,显然脱机算法可以表现得更好.联机算法的问题在于在于将大项物品装箱困难,特别是当他们在输入的晚期出现的时候.于是脱机算法我们可以将各项物品排序,把最大的物品放在最先,此时我们可以应用首次适合算法或最佳适合算法,分别得到“首次适合递减算法” 和 ”最佳适合递减算法”.
    这两种算法是差不多的,我们以首次适合递减算法为例.

    14
    代码如下,我们使用快速排序.注意java中double数有些精度问题,上面的三个算法也会出现可能无法完全装满的问题,我就不去改了.这里改一下
      package com.fredal.structure;
    import java.util.Iterator;
    import java.util.LinkedList;
    public class BinPacking {    
    static LinkedList<Box> boxes=new LinkedList<Box>();//存储所有箱子
    static int index=1;
    static class Box{//箱子类
        private double remain;//剩余容量
        private LinkedList<Double> values;
        public Box(){
            remain=1;//设容量初始为1
            values=new LinkedList<Double>();//存储箱子中的物品
        }
    }
    //首次适合递减算法
    public static void firstfit(double[] a){
        quickSort(a, 0, a.length-1);
        for(int i=0;i<a.length;i++){
            boolean flag=false;
            if(boxes.peek()==null)
                boxes.add(new Box());
            Iterator<Box> it = boxes.iterator();
            while(it.hasNext()){//从头到尾遍历 能装就装
                Box box = it.next();
                if(box.remain>a[i]||Math.abs(box.remain-a[i])<Math.pow(10, -10)){//解决一下精度
                    box.values.add(a[i]);
                    box.remain-=a[i];
                    flag=true;
                    break;
                }
            }
            if(!flag){//全部不能装就开辟新的
                Box nbox=new Box();
                nbox.values.add(a[i]);
                nbox.remain-=a[i];
                boxes.add(nbox);
            }
        }
        show(boxes);
    }
    //快速排序
    public static void quickSort(double[] a,int left,int right){
       if(left<right){//递归出口条件
           int i=left;//左指针
           int j=right;//右指针
           double x=a[left];//选择第一个元素作为标尺
           while(i<j){
               while(i<j && a[j]<=x) j--;//从右向左找第一个大于x的数
               if(i<j) a[i++]=a[j];
               while(i<j && a[i]>x) i++;//从左向右找第一个小于等于x的数
               if(i<j) a[j--]=a[i];
           }
           a[i]=x;//插入标尺
           quickSort(a,left,i-1);//递归左边
           quickSort(a, i+1, right);//递归右边
       }
    }
    
    //输出显示
    public static void show(LinkedList<Box> boxes){
        while(!boxes.isEmpty()){
            Box box = boxes.removeFirst();
            System.out.print("box"+index+++":");
            while(!box.values.isEmpty()){
                Double value = box.values.removeFirst();
                System.out.print(value+" ");
            }
            System.out.println();
        }
    }
    public static void main(String[] args) {
        double[] a={0.2,0.5,0.4,0.7,0.1,0.3,0.8};
        firstfit(a);
    }
    }
    

    分治算法

    分治法是一种很重要的算法。字面上的解释是“分而治之”,就是把一个复杂的问题分成两个或更多的相同或相似的子问题,再把子问题分成更小的子问题……直到最后子问题可以简单的直接求解,原问题的解即子问题的解的合并。这个技巧是很多高效算法的基础.
    分治策略是:对于一个规模为n的问题,若该问题可以容易地解决(比如说规模n较小)则直接解决,否则将其分解为k个规模较小的子问题,这些子问题互相独立且与原问题形式相同,递归地解这些子问题,然后将各子问题的解合并得到原问题的解。这种算法设计策略叫做分治法.
    传统上,含有两个或以上的递归调用的叫做分治法.经典的例子就是快速排序,归并排序
    接下来我们选一些其他的经典例子来分析

    • 大整数乘法

    设有两个大整数相乘,X=61438521,Y=94736407.那么XY=5820464730934047.易知我么的算法需要O(N²)即O(8²)次操作.
    如果我们把X和Y都拆成两半,由最高几位和最低几位组成.那么XL=6143,XR=8521,YL=9473,YR=6470.于是X=XL*10^4+XR,Y=YL*10^4+YR.可以得到
    XY=XL*YL*10^8+(XL*YR+XR*YL)*10^4+XRYR
    显然这个式子就是由4个乘法组成的,每一个都是原问题的一半,而108,104的乘法只是添一些0,于是可以得到递归:T(N)=4T(N/2)+O(N)..
    我们按照主定理(相关资料查阅维基百科),可以求得算法复杂度仍然是O(N²).并没有改进这个问题.
    观察XL*YR+XR*YL,可以分解为(XL-XR)(YR-YL)+XL*YL+XR*YR,我们仅需要算前面一项,后面的两项已经计算过了.于是得到了T(N)=3T(N/2)+O(N).,按照主定理,可得T(N)=O(N^1.59).当然对于每一个乘积我们还可以继续递归下去,一般到四位数就不用递归了.
    我们还是采取代码来模拟过程:

      package com.fredal.structure;
    public class BigMultiply {
       //大整数相乘
       public static String multiply(String x,String y){
           int flag1=0;//x的符号位
           int flag2=0;//y的符号位
           if(x.charAt(0)=='-'){//处理符号
               x=x.substring(1);//先把符号位截掉
               flag1=1;
           }
           if(y.charAt(0)=='-'){
               y=y.substring(1);
               flag2=1;
           }
           
           String flag=(flag1^flag2)==1?"-":"";//相乘即异或之后符号位
           
           if(x.length()<y.length())//保证x的位数更大
               return flag+multiply(y, x);
       
           if(x.length()<=4)
               return flag+Integer.parseInt(x)*Integer.parseInt(y);//少于等于四位数直接计算了
           
           if(x.length()%2==0){//x位数是偶数 就把y补成和x一样长
               while(x.length()>y.length())
                   y="0"+y;
           }else{//x位数不是偶数 就先把x补成偶数 再把y补成和x一样长
               x="0"+x;
               while(x.length()>y.length())
                   y="0"+y;
           }
           
           String xl=x.substring(0,x.length()/2);
           String xr=x.substring(x.length()/2);
           String yl=y.substring(0,y.length()/2);
           String yr=y.substring(y.length()/2);
           
           String D1=minus(xl, xr);//xl-xr
           String D2=minus(yr, yl);//yr-yl
           
           String xlyl=multiply(xl, yl);//xl*yl
           String xryr=multiply(xr, yr);//xr*yr
           
           String D3=add(multiply(D1, D2)+"", add(xlyl, xryr));//D1*D2+Xl*Yl+Xr*Yr
           return flag+add(shift(xlyl, x.length()),add(shift(D3, x.length()/2),xryr));//Xl*Yl*10^n+D3*10^(n/2)+Xr*Yr
       }
       
       //大数相减 带符号处理
       public static String minus(String x,String y){
           int large=compare(x, y);
           String flag=large>=0?"":"-";//加上符号
           if (large==0)
               return "0";
           else if(large>0)//转化成大的减小的
               return minusBigNum(x,y);
           else 
               return flag+minusBigNum(y,x);
       }
       
       //大数相减
       private static String minusBigNum(String x,String y){//大数减小数
           int len=x.length();
           while(len>y.length())
               y="0"+y;
           StringBuilder result=new StringBuilder();
           int flag=0;//表示是否进位
           for(int i=len-1;i>=0;i--){
               int xs=Integer.parseInt(String.valueOf(x.charAt(i)));
               int ys=Integer.parseInt(String.valueOf(y.charAt(i)));
               if(xs+flag>=ys){//别忘了把flag加上
                   result.append(xs-ys+flag);
                   flag=0;
               }else{
                   result.append(10+xs-ys+flag);
                   flag=-1;
               }
           }
           return clearZero(result.reverse().toString());
       }
       
       //大数相加
       public static String add(String x,String y){
           if(x.charAt(0)=='-'){//先处理符号
               x=x.substring(1);
               if(y.charAt(0)=='-'){
                   y=y.substring(1);
                   return "-"+add(x,y);
               }else
                   return minus(y, x);             
           }
           
           if(y.charAt(0)=='-'){
               y=y.substring(1);
               return minus(x, y);
           }
           if(x.length()<y.length())
               return add(y, x);//保证x的位数更大
           
           int len=x.length();
           
           while(len>y.length())//补位使位数相等
               y="0"+y;
           
           StringBuilder result=new StringBuilder();
           int flag=0;//表示是否进位
           for(int i=len-1;i>=0;i--){
               int xs=Integer.parseInt(String.valueOf(x.charAt(i)));
               int ys=Integer.parseInt(String.valueOf(y.charAt(i)));
               if(xs+ys+flag>9){//别忘了把flag加上
                   result.append(xs+ys-10+flag);
                   flag=1;
               }else{
                   result.append(xs+ys+flag);
                   flag=0;
               }
           }
           if(flag!=0)
               result.append(1);
           return clearZero(result.reverse().toString());
       }
       
       //计算10n次方的 就是后面加0
       public static String shift(String x,int n){
           for(int i=0;i<n;i++){
               x+="0";
           }
           return x;
       }
       
       //消除0
       private static String clearZero(String str){
           int i=0;
           while(i<str.length()&&str.charAt(i)=='0'){
               i++;
           }
           return str.substring(i);
       }
       
       //比较两个数的大小
       private static int compare(String x,String y){
           if(x.length()>y.length())
               return 1;
           else if(x.length()<y.length())
               return -1;
           else{
               int index = 0;
               while (index < x.length() && x.charAt(index) == y.charAt(index))
                   index++;
               if (index == x.length())
                   return 0;
               else {
                   return x.charAt(index) > y.charAt(index)? 1 : -1;
               }
           }
       }
       
       public static void main(String[] args) {
           System.out.println(multiply("-61438521", "94736407"));
           System.out.println(multiply("-3124234254543411432432422238221342421",
                   "-2423442342342342342342342323423445345699"));
                   
       }
    }
    
    

    算法并不复杂,但是处理符号什么的还是比较麻烦的,而且加法减法啥的都要自己去实现.

    • Strassen矩阵乘法

    两个矩阵的乘法学过线性代数的都知道怎么求,一般来说复杂度为O(N^3).直接给出标准的算法

      package com.fredal.structure;
    public class MartixMultiply {
       public static int[][] multiply(int[][] a, int[][] b) {
           int n = a.length;
           int[][] c = new int[n][n];
    
           for (int i = 0; i < n; i++)
               // 初始化
               for (int j = 0; j < n; j++)
                   c[i][j] = 0;
    
           for (int i = 0; i < n; i++)
               for (int j = 0; j < n; j++)
                   for (int k = 0; k < n; k++)
                       c[i][j] += a[i][k] * b[k][j];
    
           return c;
       }
    
       public static void main(String[] args) {
           int[][] a = { { 1, 2 }, { 3, 4 } };
           int[][] b = { { 3, 4 }, { 7, 2 } };
           int[][] c = multiply(a, b);
    
           System.out.println(c[0][0] + " " + c[0][1] + " " + c[1][0] + " "
                   + c[1][1]);
       }
    }
    
    

    Strassen提出了算法打破了O(N^3)的屏障.用到分治算法,把矩阵分为4块.

    15 16
    其中:
    17
    可以得到递推关系T(N)=7T(N/2)+O(N²),依据主定理得到解T(N)=O(N^2.81).
    这儿不做出证明,显然这用到了分治法的思想,我们用代码模拟
      package com.fredal.structure;
    public class MartixMultiply {
       public static int[][] StrassenMultiply(int[][] a, int[][] b) {
           int[][] result = new int[a.length][b.length];
           if (a.length == 2)
               return multiply(a, b);// 如果是2阶的 就结束递归 用传统方法
           // a的四个子矩阵
           int[][] A00 = divide(a, 1);
           int[][] A01 = divide(a, 2);
           int[][] A10 = divide(a, 3);
           int[][] A11 = divide(a, 4);
           // b的四个子矩阵
           int[][] B00 = divide(b, 1);
           int[][] B01 = divide(b, 2);
           int[][] B10 = divide(b, 3);
           int[][] B11 = divide(b, 4);
    
           int[][] m1 = StrassenMultiply(addArrays(A00, A11), addArrays(B00, B11));
           int[][] m2 = StrassenMultiply(addArrays(A10, A11), B00);
           int[][] m3 = StrassenMultiply(A00, subArrays(B01, B11));
           int[][] m4 = StrassenMultiply(A11, subArrays(B10, B00));
           int[][] m5 = StrassenMultiply(addArrays(A00, A01), B11);
           int[][] m6 = StrassenMultiply(subArrays(A10, A00), addArrays(B00, B01));
           int[][] m7 = StrassenMultiply(subArrays(A01, A11), addArrays(B10, B11));
    
           int[][] C00 = addArrays(m7, subArrays(addArrays(m1, m4), m5));// m1+m4-m5+m7
           int[][] C01 = addArrays(m3, m5); // m3+m5
           int[][] C10 = addArrays(m2, m4); // m2+m4
           int[][] C11 = addArrays(m6, subArrays(addArrays(m1, m3), m2));// m1+m3-m2+m6
    
           // 将四个矩阵合并起来
           Merge(result, C00, 1);
           Merge(result, C01, 2);
           Merge(result, C10, 3);
           Merge(result, C11, 4);
    
           return result;
       }
    
       // /分割得到子矩阵
       private static int[][] divide(int[][] a, int flag) {
           int[][] result = new int[a.length / 2][a.length / 2];
           switch (flag) {
           case 1:
               for (int i = 0; i < a.length / 2; i++)
                   for (int j = 0; j < a.length / 2; j++)
                       result[i][j] = a[i][j];
               break;
           case 2:
               for (int i = 0; i < a.length / 2; i++)
                   for (int j = a.length / 2; j < a.length; j++)
                       result[i][j - a.length / 2] = a[i][j];
               break;
           case 3:
               for (int i = a.length / 2; i < a.length; i++)
                   for (int j = 0; j < a.length / 2; j++)
                       result[i - a.length / 2][j] = a[i][j];
               break;
           case 4:
               for (int i = a.length / 2; i < a.length; i++)
                   for (int j = a.length / 2; j < a.length; j++)
                       result[i - a.length / 2][j - a.length / 2] = a[i][j];
               break;
           }
           return result;
       }
    
       // 矩阵加法
       private static int[][] addArrays(int[][] a, int[][] b) {
           int[][] result = new int[a.length][a.length];
           for (int i = 0; i < result.length; i++) {
               for (int j = 0; j < result.length; j++) {
                   result[i][j] = a[i][j] + b[i][j];
               }
           }
           return result;
       }
    
       // 矩阵减法
       private static int[][] subArrays(int[][] a, int[][] b) {
           int[][] result = new int[a.length][a.length];
           for (int i = 0; i < result.length; i++) {
               for (int j = 0; j < result.length; j++) {
                   result[i][j] = a[i][j] - b[i][j];
               }
           }
           return result;
       }
    
       // 将b复制到a的指定位置
       private static void Merge(int[][] a, int[][] b, int flag) {
           switch (flag) {
           case 1:
               for (int i = 0; i < a.length / 2; i++)
                   for (int j = 0; j < a.length / 2; j++)
                       a[i][j] = b[i][j];
               break;
           case 2:
               for (int i = 0; i < a.length / 2; i++)
                   for (int j = a.length / 2; j < a.length; j++)
                       a[i][j] = b[i][j - a.length / 2];
               break;
           case 3:
               for (int i = a.length / 2; i < a.length; i++)
                   for (int j = 0; j < a.length / 2; j++)
                       a[i][j] = b[i - a.length / 2][j];
               break;
           case 4:
               for (int i = a.length / 2; i < a.length; i++)
                   for (int j = a.length / 2; j < a.length; j++)
                       a[i][j] = b[i - a.length / 2][j - a.length / 2];
               break;
           }
       }
    
       // 常规做法
       public static int[][] multiply(int[][] a, int[][] b) {
           int n = a.length;
           int[][] c = new int[n][n];
    
           for (int i = 0; i < n; i++)
               // Initialization
               for (int j = 0; j < n; j++)
                   c[i][j] = 0;
    
           for (int i = 0; i < n; i++)
               for (int j = 0; j < n; j++)
                   for (int k = 0; k < n; k++)
                       c[i][j] += a[i][k] * b[k][j];
    
           return c;
       }
    
       public static void main(String[] args) {
           int[][] a = { { 1, 2, 6, 7 }, { 3, 4, 5, 4 }, { 5, 8, 3, 8 },
                   { -6, 4, 3, 9 } };
           int[][] b = { { 3, 4, 9, 0 }, { 7, 2, -5, -6 }, { 0, 7, -4, 6 },
                   { -6, 3, -5, 4 } };
           int[][] c = multiply(a, b);
    
           System.out.println(c[0][0] + " " + c[0][1] + " " + c[1][0] + " "
                   + c[1][1]);
       }
    }
    
    
    • 最近点对问题

    这个问题真的很有意思,给定空间上的n个节点S={(xi,yi)},如何查找这n个点对中最近的点对的距离?
    我们都知道两点间距离:((xi-xj)²+(yi-yj)²)1/2
    那么如果使用暴力搜索,需要两两检测,需要花费O(N²).
    我们可以采用分治法的思想.首先我们假设这些点都已经按照x坐标排序过,那么可以在中间画一条线,把点集分为Pl和Pr,那么最近的一对点要么都在Pl中,要么都在Pr中,要么分别在Pl和pr中,把这三个距离分别用dl,dr,dc表示.

    18
    显然我们可以递归第计算dl和dr,关键是计算dc.令δ=min(dl,dr),显然我们可以做出双道带,来缩小考虑的范围
    19
    对于在带中的,我们可以通过两层循环蛮力计算,但毕竟还是不好,最坏的情况是说有点都可能在这带状区域里.那么考虑对y坐标排序.如果pi和pj的y坐标差大于δ,那么可以直接break跳出内循环.如对于p3来说,我们只需要考虑p4和p5,之后的break就行了.
    20
    还有一点可优化的是,其实对于任意的点pi,最多有7个点需要被考虑.
    21
    如图所示,这是点最多的情况.因为如果在上图的情况中随便再添一个点,比如在左边,那么这个点距离左边其他四个点的距离肯定有小于δ的(最均匀矩形中心,距离小于δ).那么只要距离一个点小于δ.那么与之前的假设在Pl中最短的距离为δ矛盾了!所以我们只需考虑最坏情况,某个角上是pi,那么其余还有7个点要考虑,得证.
    接下来就容易了,但是还有要注意的是,不能每次递归都去排序x,y坐标.可以保留两个表,一个x坐标排序的表,一个y坐标排序的表.这个算法是O(NlogN)的.
    还是来实现吧:
      package com.fredal.structure;
    import java.util.ArrayList;
    import java.util.Set;
    import java.util.TreeSet;
    public class Distance {
       static final int NUM=10;//使用穷举法的点数   
       public static void main(String[] args) {
           Set<Point> testData = new TreeSet<Point>();         
           java.util.Random random = new java.util.Random();  
           for(int i = 0;i < 100000;i++){  
               int x = random.nextInt(100000);  
               int y = random.nextInt(100000);  
               testData.add(new Point(x, y));  
           } 
           Point [] points = new Point[testData.size()];  
           points = (Point[]) testData.toArray(points);
           Point[] result=new Point[2];
           long startTime=System.currentTimeMillis();   //获取开始时间
           result=findpair(points);
           long endTime=System.currentTimeMillis(); //获取结束时间
           System.out.println("最近点: ("+result[0].getX()+","+result[0].getY()
                   +")--("+result[1].getX()+","+result[1].getY()+")");
           System.out.println("距离为: "+distance(result[0], result[1]));
           System.out.println("分治法运行时间:"+(endTime-startTime)+"ms");
           long startTime2=System.currentTimeMillis();   //获取开始时间
           result=bong(points);
           long endTime2=System.currentTimeMillis(); //获取结束时间
           System.out.println("穷举法运行时间:"+(endTime2-startTime2)+"ms");
       }    
       //寻找最近点对
       public static Point[] findpair(Point[] p){
           Point[] result=new Point[2];
           if(p.length<NUM)
               return bong(p);
           //开始画线了  求所有点在x坐标的中位数
           int minX = (int) Double.POSITIVE_INFINITY; 
           int maxX = (int) Double.NEGATIVE_INFINITY;
           for(int i = 0; i < p.length; i++){  
               if(p[i].getX() < minX)  
                   minX = (int) p[i].getX();  
               if(p[i].getX() > maxX)  
                   maxX = (int) p[i].getX();  
           }  
           int midX = (minX + maxX)/2;
           //把以midx为界划分出的点分成两组放到两个表
           ArrayList<Point> L1=new ArrayList<Point>();
           ArrayList<Point> L2=new ArrayList<Point>();
           for(int i = 0; i < p.length; i++){  
               if(p[i].getX() <= midX)       
                   L1.add(p[i]);  
               if(p[i].getX() > midX)  
                   L2.add(p[i]);  
           } 
           //按x坐标排序
           Point [] p1 = new Point[L1.size()];  
           Point [] p2 = new Point[L2.size()];           
           L1.toArray(p1);  
           L2.toArray(p2);  
           mergeSort(p1, "x");     //按X坐标升序排列  
           mergeSort(p2, "x");     //按X坐标升序排列 
           //递归求p1,p2中最近的两个点
           Point[] result1 = new Point[2];  
           result1 = findpair(p1);
           Point[] result2 = new Point[2];  
           result2 = findpair(p2);
           //求二者中的最小值
           if (distance(result1[0], result1[1])<distance(result2[0], result2[1])) {
               result=result1;
           }else {
               result=result2;
           }
           double distance=Math.min(distance(result1[0], result1[1]),distance(result2[0], result2[1]));
           //开始划分带了  在两个子集中找哪些距离划分线小于d的保存
           ArrayList<Point> L3 = new ArrayList<Point>();   
           for(int i = 0; i < p1.length; i++){  
               if(midX - p1[i].getX() < distance)  
                   L3.add(p1[i]);  
           }  
           for(int i = 0; i < p2.length; i++){  
               if(p2[i].getX() - midX < distance){  
                   L3.add(p2[i]);  
               }  
           } 
           //将得到的按照y升序排列
           Point [] p3 = new Point [L3.size()];  
           L3.toArray(p3);            
           mergeSort(p3, "y");  
           //然后开始优化的穷举 即比较之后的7个点
           if (p3.length<NUM) {
               Point[] temp= bong(p3);
               if (distance(temp[0], temp[1])<distance&&distance(temp[0], temp[1])!=0) {
                   result=temp;
               }
           }else {
               for(int i=0;i<p3.length-7;i++){
                   double tempd;
                   for(int j=1;j<8;j++){
                       if (i+j>=p3.length) {
                           break;
                       }else {
                           tempd=distance(p3[i], p3[i+j]);
                           if (tempd<distance&tempd!=0) {
                               result[0]=p3[i];
                               result[1]=p3[i+j];
                           }
                       }                 
                   }
               }
           }
           return result;
       }    
       //归并排序
       private static void mergeSort(Point[] p, String flag) {
           Point[] result = new Point[p.length];  
           mergeSort(p, result, 0, p.length - 1, flag); 
       }
       private static void mergeSort(Point[] a, Point [] result, int left, int right, String flag){  
           if(left < right){  
               int center = (left + right) >> 1;  
               //分治  
               mergeSort(a, result, left, center, flag);  
               mergeSort(a, result, center + 1, right, flag);  
               //合并  
               merge(a, result, left, center + 1, right, flag);  
           }  
         } 
       private static void merge(Point [] a, Point [] result, int leftPos, int rightPos, int rightEnd, String flag){  
           int leftEnd = rightPos - 1;  
           int numOfElements = rightEnd - leftPos + 1;  
             
           int tmpPos = leftPos;       //游标变量, 另两个游标变量分别是leftPos 和 rightPos            
           while(leftPos <= leftEnd && rightPos <= rightEnd){  
               if(flag.equals("x")){  
                   if(a[leftPos].getX() <= a[rightPos].getX())  
                       result[tmpPos++] = a[leftPos++];  
                   else  
                       result[tmpPos++] = a[rightPos++];  
               }else if(flag.equals("y")){  
                   if(a[leftPos].getY() <= a[rightPos].getY())  
                       result[tmpPos++] = a[leftPos++];  
                   else  
                       result[tmpPos++] = a[rightPos++];  
               }else  
                   throw new RuntimeException();  
           }       
           while(leftPos <= leftEnd)  
               result[tmpPos++] = a[leftPos++];  
           while(rightPos <= rightEnd)  
               result[tmpPos++] = a[rightPos++];         
           //将排好序的段落拷贝到原数组中  
           System.arraycopy(result, rightEnd-numOfElements+1, a, rightEnd-numOfElements+1, numOfElements);  
       }      
       //穷举法
       private static Point[] bong(Point[] p){
           Point[] result=new Point[2];
           if (p.length<=1) {
               result[0]=new Point(Double.MIN_VALUE, Double.MIN_VALUE);
               result[1]=new Point(Double.MAX_VALUE, Double.MAX_VALUE);
               return result;
           }else{
               double min=distance(p[0], p[1]);
               int start=0;
               int end=1;
               for (int i = 0; i < p.length; i++) {
                   for (int j = i+1; j < p.length; j++) {
                       if (distance(p[i], p[j])<min&&distance(p[i], p[j])!=0) {
                           min=distance(p[i], p[j]);
                           start=i;
                           end=j;
                       }
                   }
               }
               result[0]=p[start];
               result[1]=p[end];
               return result;
           }
       }    
       //计算距离
       private static double distance(Point p1,Point p2){
           return Math.sqrt((p1.getX()-p2.getX())*(p1.getX()-p2.getX())+(p1.getY()-p2.getY())*(p1.getY()-p2.getY()));
       }
       //用一个类来表示点
       static class Point implements  Cloneable,Comparable<Point>{
           double x,y;
           public Point(double x, double y) {
               super();
               this.x = x;
               this.y = y;
           }
           public double getX() {
               return x;
           }
           public double getY() {
               return y;
           }
           public int compareTo(Point o) {  
               if(x == o.getX() && y == o.getY())  
                   return 0;  
               else   
                   return 1;  
           }
           public boolean equals(Object p) {
               // TODO 自动生成的方法存根
               if (this.x==((Point) p).getX()&&this.y==((Point) p).getY()) {
                   return true;
               }else {
                   return false;
               }
           }
       }    
    }
    

    我们采用了随机数产生器随机产生点,并统计了分治法与纯粹暴力搜索的运行时间,差距相当明显(这个程序当规模特别大的时候还是会栈溢出,为啥)

    22

    泊松分酒问题

    讲这道题纯粹就是比较好玩,就记录一下.泊松分酒是很著名的一道题,讲的是假设某人有12品脱的啤酒一瓶,想从中倒出六品脱,但是恰巧身边没有6品脱的容器,仅有一个8品脱和一个5品脱的容器,怎样倒才能将啤酒分为两个6品脱呢?
    我们用代码模拟很简单的就得到了答案

      package com.fredal.structure; import java.util.HashSet;
    import java.util.LinkedList;
    import java.util.Set;
    public class Oil {
        static class Status{
             static int[] full={12,8,5};//满的状态
             int[] bottle=new int[3];//瓶子的状态
             Status from;//从哪个状态来的
             
             public Status(int a,int b,int c){
                 bottle[0]=a;
                 bottle[1]=b;
                 bottle[2]=c;
             }
             
             //获取某种状态开始下一步的所有的状态
             public Set opreation(){
                 Set res=new HashSet();
                 
                 //开始倒酒
                 for(int i=0;i<bottle.length;i++){
                     for(int j=0;j<bottle.length;j++){
                         if(i==j) continue; //不倒自己
                         if(bottle[i]==0) continue;//自己是空的 不倒
                         if(bottle[j]==full[j]) continue;//对方是满的 不倒
                         
                         Status t=new Status(bottle[0], bottle[1], bottle[2]);
                         t.from=this;//从自己这个状态开始变化
                         
                         //真的开始倒酒了                     t.bottle[j]+=t.bottle[i];
                         t.bottle[i]=0;
                         if(t.bottle[j]>full[j]){//装不下了
                             t.bottle[i]=t.bottle[j]-full[j];//满的倒回去
                             t.bottle[j]=full[j];
                         }              
                         res.add(t);
                     }
                 }
                 return res;
             }
             
             //是否含有某种状态
             public boolean has2(int x){
                int index=0;
                if (bottle[0]==x) index++;
                if (bottle[1]==x) index++;
                if (bottle[2]==x) index++;
                return index==2?true:false;
             }
             
             public Status getFrom() {
                return from;
            }
             
             public String toString(){
                    return "<" + bottle[0] + "," + bottle[1] + "," + bottle[2] + ">";
            }
             
            public int hashCode() {
                return 100;
            }
            
            public boolean equals(Object obj) {
                Status x=(Status)obj;
                return bottle[0]==x.bottle[0]&&bottle[1]==x.bottle[1]&&bottle[2]==x.bottle[2];
            }
        }
        
        public static void main(String[] args) {
            Set<Status> all=new HashSet<Status>();//存放所有结果状态
            all.add(new Status(12, 0, 0));
            
            for(;;){
                Set newset=new HashSet();
                
                for(Status x:all){//所有上一种状态产生所有下一种状态
                    Set t = x.opreation();
                    newset.addAll(t);
                }
                
                if(all.containsAll(newset)) break;//出口
                all.addAll(newset);
            }
            
            LinkedList<Status> list=new LinkedList<Status>();//存放有6的一溜
            
            for(Status k:all){
                if(k.has2(6)){
                    while(k!=null){    
                        list.push(k);
                        k=k.getFrom();//从终止状态开始往上追溯
                    }
                }
            }   
            //输出
            while(!list.isEmpty()){
                System.out.println(list.pop());
            }
        }    
    }
    

    这个解法找到的其实是最优解,至于为什么呢,其实利用set的方法十分巧妙,结果集set里随着一次次的分酒一次次地扩增,当第一次出现含有两个6的状态的时候,再往前追溯,步骤是最少的!因为这个我们想要的状态是第一次出现.
    假如我们每次都打印出all集合,可以知道,当第一次找到含有两个6状态的时候程序并没有结束,因为还没有找到所有的状态.
    而后面的状态再进行分酒时,仍有可能产生两个6的状态,但是想要加入set集合的时候就行不通了,所以此程序只输出最早加入的那一个解,并且是最优的.
    当然这种算法并不能输出所有的解,如果要得到所有的解,我们可以采用以下算法,这种算法借鉴了图的深度搜索(DFS)以及回溯的技巧,需要注意的是,和8皇后问题一样,需要回溯的时机有两个,出错的时候和找到某一组解的时候.

      package com.fredal.structure;
    
    import java.util.ArrayList;
    import java.util.List;
    import java.util.Scanner;
    
    public class Oil {
        int[] full = new int[3]; //满状态 容量
        int[] bottle = new int[3]; //瓶子的状态
        int target = 0; //目标
        List<int[]> res = new ArrayList<int[]>();//存放结果
        
        public void opreation(int[] bottle) {
            for(int i=0;i<3;i++) {
                for(int j=1;j<3;j++){//每个瓶子都不往自己倒 总共6种可能性
                    int[] temp = bottle.clone();//每次循环都创建临时数组 
                    int to=(i+j)%3;//(i+j)%3 是除每种i瓶子外其他两个瓶子的序号,即要倒的目标
                    if(temp[i]==0) continue;//自己是空的 不倒
                    if(temp[to]==full[to]) continue;//对方是满的 不倒
                    
                    //开始倒酒
                    temp[to]+=temp[i];
                    temp[i]=0;
                    if(temp[to]>full[to]){//装不下了
                        temp[i]=temp[to]-full[to];//满出来的部分倒回去
                        temp[to]=full[to];
                    }
                    
                    if(had(temp)) continue;//检测是否已经存在相同状态,防止重复
    
                    res.add(temp);//添加到结果链表
                    if(has2(temp))    return;//如果找到有两个想要的状态的结果就返回
                    opreation(temp);//继续下一次分酒
                    res.remove(res.size()-1); //回溯 仔细体会
                }
            }
        }
    
        //是否以及含有状态
        private boolean had(int[] bottlex) {
            for(int[] e:res)
                if(e[0]==bottlex[0]&&e[1]==bottlex[1]&&e[2]==bottlex[2]) return true;
            return false;
        }
    
        //检测找到结果
        private boolean has2(int[] bottle) {
            int index=0;
            for(int i=0;i<bottle.length;i++)        
                if(bottle[i]==target) index++;        
            if(index==2){
                show(res);//输出
                res.remove(res.size()-1);//回溯
                return true;
            }
            return false;
        }
        //打印
        private void show(List<int[]> res) {
            for(int[] e:res) {
                System.out.println(e[0] + "," + e[1] + "," + e[2]);
            }
            System.out.println();
        }
    
        public static void main(String[] args) {
            Oil o = new Oil();
            Scanner scanner = new Scanner(System.in);
            String s =""; 
            if(scanner.hasNext()) {
                s = scanner.nextLine();
            }
            String[] data = s.split(",");
            int[] d = new int[data.length];
            for(int i=0;i<data.length;i++){
                d[i] = Integer.parseInt(data[i]);
            }
            o.full = new int[]{d[0],d[1],d[2]};
            o.bottle = new int[]{d[3],d[4],d[5]};
            o.target = d[6];
            o.res.add(new int[]{d[3],d[4],d[5]});//添加初始状态
            o.opreation(o.bottle);
        }
        
    }
    

    显然,按照深度搜索并不能有效地找到最优解.上面两种算法都是比较巧的,我也比较喜欢.
    如果要同时找到所有解和最优解,用图的广度搜索(BFS)会很方便,这也是网上采用的最多的,代码到处都有,就不写了.
    更多内容与相关下载请查看扩展阅读

    相关文章

      网友评论

      • 衤刀学者:这个文章我感觉非常棒。可以结合一些校招试题来讲解。今年的校招笔试题装箱问题是个热门话题。我已经收录了。打算转载。
      • 陈达也:很赞!是否可以转载?
        fredal:@陈达也 可以`~

      本文标题:五大常用算法二(贪心,分治)

      本文链接:https://www.haomeiwen.com/subject/pxfikttx.html