美文网首页
A_star算法(人工智能1)

A_star算法(人工智能1)

作者: 小火伴 | 来源:发表于2018-01-17 22:44 被阅读33次
    01.png 02.png 03.png 04.png 05.png 06.png 07.png 08.png 09.png 10.png 11.png

    head.h

    #include <iostream>
    #include <vector>
    #include <queue>
    #include <string>
    #include <list>
    #include <algorithm>
    
    
    using namespace std;
    
    class Node {
        //节点类,收集所有的不同的被需要的元素,例如估计s->n最小代价g,n-1->n估计最小代价h,深度
    public:
        vector<int> state;
        int gn;
        int hn;
        int depth;
        int total;
        int blank;
        vector<int> parent;
        Node(vector<int> a, int b, int c, int d, int e, int f, vector<int> g)
        {
            this->state = a;
            this->gn = b;
            this->hn = c;
            this->depth = d;
            this->total = e;
            this->blank = f;
            this->parent = g;
        }
    };
    
    vector<Node> merge(vector<Node> a, vector<Node> b)
    //结合归并排序,把两个容器混合成一个
    {
        vector<Node> c;
        int i = 0;
        int j = 0;
        while ((i < a.size()) && (j < b.size()))
        {
            Node look = a.at(i);
            Node there = b.at(j);
            if (look.total > there.total)
            {
                c.push_back(there);
                ++j;
            }
            else
            {
                c.push_back(look);
                ++i;
            }
        }
        while (i < a.size())
        {
            Node look = a.at(i);
            c.push_back(look);
            ++i;
        }
        while (j < b.size())
        {
            Node there = b.at(j);
            c.push_back(there);
            ++j;
        }
        return c;
    }
    
    vector<Node> mergesort(vector<Node> arr)
    //归并排序
    {
        if (arr.size() == 1)
        {
            return arr;
        }
        vector<Node> l1;
        int left = arr.size() / 2;
        for (int i = 0; i < left; ++i)
        {
            l1.push_back(arr.at(i));
        }
        vector<Node> r1;
        for (int i = left; i < arr.size(); ++i)
        {
            r1.push_back(arr.at(i));
        }
        l1 = mergesort(l1);
        r1 = mergesort(r1);
    
        return merge(l1, r1);
    }
    
    void orderSort(queue<Node> &input)
    //调用上面的归并排序
    {
        vector<Node> holder;
        while (input.size() > 0)
        {
            holder.push_back(input.front());
            input.pop();
        }
        holder = mergesort(holder);
        for (int i = 0; i < holder.size(); ++i)
        {
            input.push(holder.at(i));
        }
    }
    
    int manDist(vector<int> a, vector<int> &b)
    //计算曼哈顿距离
    {
        int goal[3][3] = { { 1, 2, 3 },{ 4, 5, 6 },{ 7, 8, 0 } };
        int comp[3][3] = { { a.at(0), a.at(1), a.at(2) },{ a.at(3), a.at(4), a.at(5) },{ a.at(6), a.at(7), a.at(8) } };
        int totaldiff = 0;
        int deltx = 0;
        int delty = 0;
        int look = 0;
        for (int k = 0; k < 3; ++k)
        {
            for (int l = 0; l < 3; ++l)
            {
                look = comp[k][l];
                for (int i = 0; i < 3; ++i)
                {
                    for (int j = 0; j < 3; ++j)
                    {
                        if (goal[i][j] == look)
                        {
                            deltx = abs(k - i);
                            delty = abs(l - j);
                            totaldiff = (totaldiff + (deltx + delty));
                        }
                    }
                }
            }
        }
        return totaldiff;
    }
    
    int missPlaced(vector<int> a, vector<int> &b)
    //错位距离
    {
        int diff = 0;
        for (int i = 0; i < a.size(); ++i)
        {
            if (a.at(i) != b.at(i))
            {
                ++diff;
            }
        }
        return diff;
    }
    
    void move(Node expand, queue<Node> &a, vector<int> &b, int &heurType)
    //扩展所有节点
    {
        int expandgn = expand.gn;
        int expandhn = expand.hn;
        int expandDepth = expand.depth;
        int expandTotal = expand.total;
        int expandBlank = expand.blank;
        int check = expand.blank;
        vector<int> modify = expand.state;
        vector<int> use = modify;
        Node right = Node(use, 1, 0, 1, 1, 1, modify);
        bool right1 = false;
        Node left = Node(use, 1, 0, 1, 1, 1, modify);
        bool left1 = false;
        Node up = Node(use, 1, 0, 1, 1, 1, modify);
        bool up1 = false;
        Node down = Node(use, 1, 0, 1, 1, 1, modify);
        bool down1 = false;
        //输出拓展情况
        /*cout << "拓展节点: " << endl;
        for(int i=0; i < modify.size(); ++i) //
        {
        if((i == 3) || (i == 6))
        {
        cout << endl;
        }
        cout << modify.at(i) << " ";
        }
        cout << endl;
        cout << " g(n) = " << expandgn << "  h(n) = " << expandhn << endl;*/
        int temp = 0;
        int missHeur = 0;
        int manHeur = 0;
        expandgn = expandgn + 1;
        if (check + 1 <= 8)
            //检查能否向右边移动
        {
            if ((check != 2) && (check != 5))
                //不在边上
            {
                temp = use.at(check);
                use.at(check) = use.at(check + 1);
                use.at(check + 1) = temp;
                if (use != expand.parent)   //检查确定不是父节点(重复的状态)
                {
                    right1 = true;
                    if (heurType == 1)  //等代价搜索(只有深度代价h,没有节点之间代价g)
                    {
                        right = Node(use, expandgn, 0, (expandDepth + 1), expandgn, (expandBlank + 1), modify);
                    }
                    else if (heurType == 2) //错位距离
                    {
                        missHeur = missPlaced(use, b);
                        right = Node(use, expandgn, missHeur, (expandDepth + 1), (expandgn + missHeur), (expandBlank + 1), modify);
                    }
                    else    //曼哈顿距离
                    {
                        manHeur = manDist(use, b);
                        right = Node(use, expandgn, manHeur, (expandDepth + 1), (expandgn + manHeur), (expandBlank + 1), modify);
                    }
                }
            }
        }
        use = modify;
        if (check - 1 >= 0) //可以左移
        {
            if ((check != 3) && (check != 6))   //不在边上
            {
                temp = use.at(check);
                use.at(check) = use.at(check - 1);
                use.at(check - 1) = temp;
                if (use != expand.parent)
                {
                    left1 = true;
                    if (heurType == 1)  //等代价搜索(只有深度代价h,没有节点之间代价g)
                    {
                        left = Node(use, expandgn, 0, (expandDepth + 1), expandgn, (expandBlank - 1), modify);
                    }
                    else if (heurType == 2) //错位距离
                    {
                        missHeur = missPlaced(use, b);
                        left = Node(use, expandgn, missHeur, (expandDepth + 1), (expandgn + missHeur), (expandBlank - 1), modify);
                    }
                    else        //曼哈顿距离
                    {
                        manHeur = manDist(use, b);
                        left = Node(use, expandgn, manHeur, (expandDepth + 1), (expandgn + manHeur), (expandBlank - 1), modify);
                    }
                }
            }
        }
        use = modify;
        if (check + 3 <= 8)     //可以上移
        {
            temp = use.at(check);
            use.at(check) = use.at(check + 3);
            use.at(check + 3) = temp;
            if (use != expand.parent)   //没有重复
            {
                up1 = true;
                if (heurType == 1)  //等代价搜索(只有深度代价h,没有节点之间代价g)
                {
                    up = Node(use, expandgn, 0, (expandDepth + 1), (expandgn + 1), (expandBlank + 3), modify);
                }
                else if (heurType == 2)     //错位距离
                {
                    missHeur = missPlaced(use, b);
                    up = Node(use, expandgn, missHeur, (expandDepth + 1), (expandgn + missHeur), (expandBlank + 3), modify);
                }
                else        //曼哈顿距离
                {
                    manHeur = manDist(use, b);
                    up = Node(use, expandgn, manHeur, (expandDepth + 1), (expandgn + manHeur), (expandBlank + 3), modify);
                }
            }
        }
        use = modify;
        if (check - 3 >= 0) //可以向下移
        {
            temp = use.at(check);
            use.at(check) = use.at(check - 3);
            use.at(check - 3) = temp;
            if (use != expand.parent)   //没有重复
            {
                down1 = true;
                if (heurType == 1)      //等代价搜索(只有深度代价h,没有节点之间代价g)
                {
                    down = Node(use, expandgn, 0, (expandDepth + 1), (expandgn + 1), (expandBlank - 3), modify);
                }
                else if (heurType == 2)     //错位距离
                {
                    missHeur = missPlaced(use, b);
                    down = Node(use, expandgn, missHeur, (expandDepth + 1), (expandgn + missHeur), (expandBlank - 3), modify);
                }
                else        //曼哈顿距离
                {
                    manHeur = manDist(use, b);
                    down = Node(use, expandgn, manHeur, (expandDepth + 1), (expandgn + manHeur), (expandBlank - 3), modify);
                }
            }
        }
        if (right1) //只添加已经生成的节点
        {
            a.push(right);
        }
        if (left1)
        {
            a.push(left);
        }
        if (up1)
        {
            a.push(up);
        }
        if (down1)
        {
            a.push(down);
        }
        if (heurType != 1)  //因为启发式搜索成本会变化,所以不对等代价搜索排序
        {
            orderSort(a);
        }
    }
    
    void mySearch(vector<int> &start, vector<int> &finish, int &blank, int &score)
    {
        queue<Node> work;   //所有节点的队列OPEN
        vector<int> par;
        work.push(Node(start, 0, 0, 0, 0, blank, par));     //创建初始节点
        Node temp = work.front();
        bool done = false;
        bool success = false;
        int added = 0;
        int maxSize = work.size();
        while (!done)           //没有找到
        {
            if (work.size() > maxSize)      //最大队列长度
            {
                maxSize = work.size();
            }
    
            if (work.size() > 0)        //队列长度大于0
            {
                temp = work.front();
                work.pop();
                if (temp.state == finish)//目标状态
                {
                    done = true;
                    success = true;
                }
                else
                {
                    move(temp, work, finish, score);    //拓展当前节点
                    ++added;            //看看我们已经探索的节点数
                }
            }
            else
            {
                done = true;
                success = false;
            }
        }
        if (success)        //目标找到
        {
            cout << "发现目标" << endl;
            cout << "求结这个问题算法探索了 " << added << " 个节点。" << endl;
            cout << "队列中节点数峰值为" << maxSize << "。" << endl;
            cout << "目标节点的深度是" << temp.depth << "" << endl;
        }
    }
    

    main.cpp

    #include "head.h"
    #include <time.h>
    
    #include <iostream>
    using namespace std;
    void getInput(vector<int> &store) //接收用户输入
    {
        char input;
        int keep;
        bool invalid = false;
        vector<int> temp;
        for (int i = 0; i < 3; ++i)
        {
            cin >> input;
            if ((input < '0') || (input > '8') && (input != ' ')) //异常值检查
            {
                invalid = true;
            }
            else
            {
                if (input != ' ') //临时存储图
                {
                    keep = input - '0';
                    temp.push_back(keep);
                }
            }
        }
        if (invalid) //异常值报错并循环请求
        {
            cout << "输入了无效字符,请重新输入当前行";
            temp.clear();       //清空本次输入
            getInput(temp);
        }
        for (int i = 0; i < temp.size(); ++i) //真正存储当前图
        {
            store.push_back(temp.at(i));
        }
    }
    
    int main()
    {
        cout << "示例:" << endl << "4 1 3 " << endl << "7 5 0 " << endl << "8 2 6" << endl;
        cout << "1. **Uniform cost search:**" << endl;
        cout << "* _Goal_ = found" << endl;
        cout << "* _Nodes expanded_ = 395*"<<endl;
        cout<< "_Max number of nodes in queue_ = 312" << endl;
        cout << "2. **A* with misplaced heuristic : **" << endl;
        cout << "* _Goal_ = found" << endl;
        cout << "* _Nodes expanded_ = 26" << endl;
        cout << "* _Max number of nodes in queue_ = 24" << endl;
        cout << "3. **A* with manhattan distance : **" << endl;
        cout << "* _Goal_ = found" << endl;
        cout << "* _Nodes expanded_ = 22" << endl;
        cout << "* _Max number of nodes in queue_ = 21" << endl << endl;
    
        vector<int> init;
        cout << "请输入你的初始状态,用0表示空的格" << endl;
        bool okay = false;
        int counter = 0;
        while (!okay) //获取用户输入
        {
            counter = 0;
            cout << "输入第一行:";
            getInput(init);
            cout << "第二行:";
            getInput(init);
            cout << "三:";
            getInput(init);
            for (int i = 0; i < init.size(); ++i) //查看空格数量
            {
                if (init.at(i) == 0)
                {
                    counter++;
                }
            }
            if (counter == 0)
            {
                cout << "你没输入空格,请重新输入" << endl;
                init.clear();
            }
            else if (counter == 1)
            {
                okay = true;
            }
            else
            {
                cout << "你输入不只一个空格,请重新输入" << endl;
                init.clear();
            }
        }
        cout << endl << "你输入的图是 " << endl;
        for (int i = 0; i < init.size(); ++i) //输出初始图
        {
            if ((i == 3) || (i == 6))
            {
                cout << endl;
            }
            cout << init.at(i) << " ";
        }
        cout << endl;
        vector<int> goal;   //目标状态
        for (int i = 1; i < init.size(); ++i)
        {
            goal.push_back(i);
        }
        goal.push_back(0);
        int blankLoc = 0;   //找出空白位置
        for (int i = 0; i < init.size(); ++i)
        {
            if (init.at(i) == 0)
            {
                blankLoc = i;
            }
        }
        cout << "你的空格在: " << blankLoc << endl << endl;
        okay = false;
        int choice = 0;
        while (!okay)
        {
            cout << "请选择你想用哪一个算法 " << endl;
            cout << "1. 宽度优先搜索" << endl;
            cout << "2. 用错位误差的A*" << endl;
            cout << "3. 用曼哈顿距离的A*" << endl;
            cout << "你选的是:";
            cin >> choice;
            if ((choice >= 1) && (choice <= 3))
            {
                okay = true;
            }
            else
            {
                cout << "选择无效" << endl;
            }
        }
        cout << endl;
        //计算运行时间
        clock_t start, finish;
        double totaltime;
        start = clock();
        if (choice == 1)
        {
    
            mySearch(init, goal, blankLoc, choice);
    
        }
        else if (choice == 2)
        {
            mySearch(init, goal, blankLoc, choice);
        }
        else
        {
            mySearch(init, goal, blankLoc, choice);
        }
        finish = clock();
        totaltime = (double)(finish - start) / CLOCKS_PER_SEC;
        cout << "此程序的运行时间为" << totaltime << "秒!" << endl;
        char hh;
        cin>>hh;
        return 0;
    }
    
    

    相关文章

      网友评论

          本文标题:A_star算法(人工智能1)

          本文链接:https://www.haomeiwen.com/subject/agdpoxtx.html