美文网首页JavaJava 杂谈
代码实现TCP三次握手:程序实现

代码实现TCP三次握手:程序实现

作者: 望月从良 | 来源:发表于2019-08-14 17:47 被阅读3次

    本节我们通过代码来实现TCP协议连接时的三次握手过程。首先我们需要再次重温一下TCP数据包的相关结构:


    屏幕快照 2019-07-27 下午4.43.06.png

    我们们将依照上面结构所示来构建数据包,相关代码如下:

    public class TCPProtocolLayer implements IProtocol {
        private static int HEADER_LENGTH = 20;
        private int sequence_number = 2;
        private int acknowledgement_number = 0;
        private static int PSEUDO_HEADER_LENGTH = 12;
        public static byte TCP_PROTOCOL_NUMBER = 6;
        private static int POSITION_FOR_DATA_OFFSET = 12;
        private static int POSITION_FOR_CHECKSUM = 16;
        private static byte MAXIMUN_SEGMENT_SIZE_OPTION_LENGTH = 4;
        private static byte MAXIMUN_SEGMENT_OPTION_KIND = 2;
        private static byte WINDOW_SCALE_OPTION_KIND = 3;
        private static byte WINDOW_SCALE_OPTION_LENGTH = 3;
        private static byte WINDOW_SCALE_SHIFT_BYTES = 6;
        private static byte TCP_URG_BIT = (1 << 5);
        private static byte TCP_ACK_BIT = (1 << 4);
        private static byte TCP_PSH_BIT = (1 << 3);
        private static byte TCP_RST_BIT = (1 << 2);
        private static byte TCP_SYN_BIT = (1 << 1);
        private static byte TCP_FIN_BIT = (1);
        @Override
        public byte[] createHeader(HashMap<String, Object> headerInfo) {
            short data_length = 0;
            byte[] data = null;
            if (headerInfo.get("data") != null) {
                data = (byte[])headerInfo.get("data");
            }
            byte[] header_buf = new byte[HEADER_LENGTH];
            ByteBuffer byteBuffer = ByteBuffer.wrap(header_buf);
            if (headerInfo.get("src_port") == null) {
                return null;
            }
            short srcPort = (short)headerInfo.get("src_port");
            byteBuffer.putShort(srcPort);
            if (headerInfo.get("dest_port") == null) {
                return  null;
            }
            short  destPort = (short)headerInfo.get("dest_port");
            byteBuffer.putShort(destPort);
            
            //设置初始序列号
            if (headerInfo.get("seq_num") != null) {
                sequence_number = (int)headerInfo.get("seq_num");
            }
            if (headerInfo.get("ack_num") != null) {
                acknowledgement_number = (int)headerInfo.get("ack_num");
            }
            byteBuffer.putInt(sequence_number); 
            byteBuffer.putInt(acknowledgement_number);
            short control_bits = 0;
            //设置控制位
            if (headerInfo.get("URG") != null) {
                control_bits |= (1 << 5);
            }
            if (headerInfo.get("ACK") != null) {
                control_bits |= (1 << 4);
            }
            if (headerInfo.get("PSH") != null) {
                control_bits |= (1 << 3);
            }
            if (headerInfo.get("RST") != null) {
                control_bits |= (1 << 2);
            }
            if (headerInfo.get("SYN") != null) {
                control_bits |= (1 << 1);
            }
            if (headerInfo.get("FIN") != null) {
                control_bits |= (1);
            }
            byteBuffer.putShort(control_bits);
            System.out.println(Integer.toBinaryString(control_bits));
            
            char window = 65535;
            byteBuffer.putChar(window);
            short check_sum = 0;
            byteBuffer.putShort(check_sum);
            short urgent_pointer = 0;
            byteBuffer.putShort(urgent_pointer);
            
            byte[] maximun_segment_option = new byte[MAXIMUN_SEGMENT_SIZE_OPTION_LENGTH];
            ByteBuffer maximun_segment_buffer =  ByteBuffer.wrap(maximun_segment_option);
            maximun_segment_buffer.put(MAXIMUN_SEGMENT_OPTION_KIND);
            maximun_segment_buffer.put(MAXIMUN_SEGMENT_SIZE_OPTION_LENGTH);
            short segment_size = 1460;
            maximun_segment_buffer.putShort(segment_size);
            
            byte[] window_scale_option = new byte[WINDOW_SCALE_OPTION_LENGTH];
            ByteBuffer window_scale_buffer = ByteBuffer.wrap(window_scale_option);
            window_scale_buffer.put(WINDOW_SCALE_OPTION_KIND);
            window_scale_buffer.put(WINDOW_SCALE_OPTION_LENGTH);
            window_scale_buffer.put(WINDOW_SCALE_SHIFT_BYTES);
            
            byte[] option_end = new byte[1];
            option_end[0] = 0;
            
            int total_length = data_length + header_buf.length + maximun_segment_option.length + window_scale_option.length + option_end.length;
            //总长度必须是4的倍数,不足的话以0补全
            if (total_length % 4 != 0) {
                total_length = (total_length / 4 + 1) * 4;
            }
            byte[] tcp_buffer = new byte[total_length];
            ByteBuffer buffer = ByteBuffer.wrap(tcp_buffer);
            buffer.put(header_buf);
            buffer.put(maximun_segment_option);
            buffer.put(window_scale_option);
            buffer.put(option_end);
            short data_offset = buffer.getShort(POSITION_FOR_DATA_OFFSET);
            data_offset |= (((total_length / 4) & 0x0F) << 12);
            System.out.println(Integer.toBinaryString(data_offset));
            buffer.putShort(POSITION_FOR_DATA_OFFSET, data_offset);
            check_sum = (short)compute_checksum(headerInfo, buffer);
            buffer.putShort(POSITION_FOR_CHECKSUM, check_sum);
            return buffer.array();
        }
        
        private long compute_checksum(HashMap<String, Object> headerInfo, ByteBuffer tcp_buffer) {
            byte[] pseudo_header = new byte[PSEUDO_HEADER_LENGTH];
            ByteBuffer pseudo_header_buf = ByteBuffer.wrap(pseudo_header);
            byte[] src_addr = (byte[])headerInfo.get("src_ip");
            byte[] dst_addr = (byte[])headerInfo.get("dest_ip");
            pseudo_header_buf.put(src_addr);
            pseudo_header_buf.put(dst_addr);
            byte reserved = 0;
            pseudo_header_buf.put(reserved);
            pseudo_header_buf.put(TCP_PROTOCOL_NUMBER);
            short tcp_length = (short)tcp_buffer.array().length;
            //将伪包头和tcp包头内容合在一起计算校验值
            byte[] total_buffer = new byte[PSEUDO_HEADER_LENGTH + tcp_buffer.array().length];
            ByteBuffer total_buf = ByteBuffer.wrap(total_buffer);
            total_buf.put(pseudo_header);
            total_buf.put(tcp_buffer.array());
            return Utility.checksum(total_buffer, total_buffer.length);
        }
    
        @Override
        public HashMap<String, Object> handlePacket(Packet packet) {
            ByteBuffer buffer= ByteBuffer.wrap(packet.header);
            HashMap<String, Object> headerInfo = new HashMap<String, Object>();
            short src_port = buffer.getShort();
            headerInfo.put("src_port", src_port);
            short dst_port = buffer.getShort();
            headerInfo.put("dest_port", dst_port);
            int seq_num = buffer.getInt();
            headerInfo.put("seq_num", seq_num);
            int ack_num = buffer.getInt();
            headerInfo.put("ack_num", ack_num);
            short control_bits = buffer.getShort();
            if ((control_bits & TCP_ACK_BIT) != 0) {
                headerInfo.put("ACK", 1);
            }
            if ((control_bits & TCP_SYN_BIT) != 0) {
                headerInfo.put("SYN", 1);
            }
            if ((control_bits & TCP_FIN_BIT) != 0) {
                headerInfo.put("FIN", 1);
            }
            short win_size = buffer.getShort();
            headerInfo.put("window", win_size);
            //越过校验值
            buffer.getShort();
            short urg_pointer = buffer.getShort();
            headerInfo.put("urg_ptr", urg_pointer);
            return headerInfo;
        }
    }
    

    上面代码实现了协议层TCP的封包与解包,在函数createHeader中,我们按照上图结构填写相关包头的字段,在函数handlePacket中,我们根据包头的字段获取相应信息。

    在ProtocolManager中转层,我们实现下面代码:

    private void handleTCPPacket(Packet packet,  HashMap<String, Object> infoFromUpLayer) {
            IProtocol tcpProtocol = new TCPProtocolLayer();
            HashMap<String, Object> headerInfo = tcpProtocol.handlePacket(packet);
            short dstPort = (short)headerInfo.get("dest_port");
            //根据端口获得应该接收UDP数据包的程序
            IApplication app = ApplicationManager.getInstance().getApplicationByPort(dstPort);
            if (app != null) {
                app.handleData(headerInfo); 
            }
        }
    

    一旦程序通过JPCap收到TCP包后,它会让上面实现的TCPProtocolLayer去解析数据包内的各个字段,然后检测数据包对应的端口是否在应用层有对应的接收对象,如果有的话,它就将解析信息转交给应用层的接收对象,接下来我们看应用层的相关实现:

    public class TCPThreeHandShakes extends Application{
        private byte[] dest_ip;
        private short dest_port;
        private int ack_num = 0;
        private int seq_num = 0;
        public TCPThreeHandShakes(byte[] server_ip, short server_port) {
            this.dest_ip = server_ip;
            this.dest_port = server_port;
             //指定一个固定端口,以便抓包调试  
            this.port = (short)11940;
        }
        
       public void beginThreeHandShakes() throws Exception {
           createAndSendPacket(null, "SYN");
       }
       
       private void createAndSendPacket(byte[] data, String flags) throws Exception {
           byte[] tcpHeader = createTCPHeader(null, flags);
           if (tcpHeader == null) {
                throw new Exception("tcp Header create fail");
            }   
           byte[] ipHeader = createIP4Header(tcpHeader.length);
           byte[] packet  = new byte[tcpHeader.length + ipHeader.length];
           ByteBuffer packetBuffer = ByteBuffer.wrap(packet);
           packetBuffer.put(ipHeader);
           packetBuffer.put(tcpHeader);
           sendPacket(packet);
       }
       
       private void sendPacket(byte[] packet) {
           try {
                InetAddress ip = InetAddress.getByName("192.168.2.1");
                ProtocolManager.getInstance().sendData(packet, ip.getAddress());
            } catch (Exception e) {
                // TODO Auto-generated catch block
                e.printStackTrace();
            }
       }
       
       private byte[] createTCPHeader(byte[] data, String flags) {
           IProtocol tcpProto = ProtocolManager.getInstance().getProtocol("tcp");
            if (tcpProto == null) {
                return null;
            }
            HashMap<String, Object> headerInfo = new HashMap<String, Object>();
            byte[] src_ip = DataLinkLayer.getInstance().deviceIPAddress();
            headerInfo.put("src_ip", src_ip);
            headerInfo.put("dest_ip", this.dest_ip);
            headerInfo.put("src_port", (short)this.port);
            headerInfo.put("dest_port", this.dest_port);
            headerInfo.put("seq_num", seq_num);
            headerInfo.put("ack_num", ack_num);
            String[] flag_units = flags.split(","); 
            for(int i = 0; i < flag_units.length; i++) {
                headerInfo.put(flag_units[i], 1);
            }
            
            byte[] tcpHeader = tcpProto.createHeader(headerInfo);
            return tcpHeader;
       }
       
       protected byte[] createIP4Header(int dataLength) {
            IProtocol ip4Proto = ProtocolManager.getInstance().getProtocol("ip");
            if (ip4Proto == null || dataLength <= 0) {
                return null;
            }
            //创建IP包头默认情况下只需要发送数据长度,下层协议号,接收方ip地址
            HashMap<String, Object> headerInfo = new HashMap<String, Object>();
            headerInfo.put("data_length", dataLength);
            ByteBuffer destIP = ByteBuffer.wrap(this.dest_ip);
            headerInfo.put("destination_ip", destIP.getInt());
            byte protocol = TCPProtocolLayer.TCP_PROTOCOL_NUMBER;
            headerInfo.put("protocol", protocol);
            headerInfo.put("identification", (short)this.port);
            byte[] ipHeader = ip4Proto.createHeader(headerInfo);
            
            return ipHeader;
        }
       
       @Override
        public void handleData(HashMap<String, Object> headerInfo) {
           short src_port = (short)headerInfo.get("src_port");
           System.out.println("receive TCP packet with port:" + src_port);
           boolean ack =  false, syn = false;
           if (headerInfo.get("ACK") != null) {
               System.out.println("it is a ACK packet");
               ack = true;
           }
           if (headerInfo.get("SYN") != null) {
               System.out.println("it is a SYN packet");
               syn = true;
           }
           if (ack && syn) {
               int seq_num = (int)headerInfo.get("seq_num");
               int ack_num = (int)headerInfo.get("ack_num");
               System.out.println("tcp handshake from othersize with seq_num" + seq_num + " and ack_num: " + ack_num);
               this.seq_num = ack_num + 1;
               try {
                createAndSendPacket(null, "ACK");
            } catch (Exception e) {
                // TODO Auto-generated catch block
                e.printStackTrace();
            }
           }
           
       }
    }
    

    应用层对象的主要目标是实现TCP连接的三次握手功能。它首先构造了一个TCP数据包,将SYN控制位打开,然后将数据包发送给目标服务器。然后等待对方回应数据包,一旦本机收到对方回发的ACK数据包后,会将数据包内的相关信息转交给当前应用对象,它解读出对方ACK包中回复的ACK数值后,将该数值加一然后再次构造一个ACK包发送给对方,上面程序运行后通过wireshark抓包可看到如下显示:

    屏幕快照 2019-08-14 下午5.35.18.png

    由此可见,我们成功的完成了TCP协议连接时的三次握手功能,上图显示中有一个数据包设置了RST标志位,它表示重置连接,这个数据包其实不是我们的应用对象发送,很可能是我们绕过了系统网络层发送数据包,当对方数据包回来时,操作系统的网络层发现接收对象没有在它内部不存在,于是自己构造了一个RST数据包发回给对方。

    更详细的讲解和代码调试演示过程,请点击链接

    更多技术信息,包括操作系统,编译器,面试算法,机器学习,人工智能,请关照我的公众号:


    这里写图片描述 新书上架,请诸位朋友多多支持: WechatIMG1.jpeg

    相关文章

      网友评论

        本文标题:代码实现TCP三次握手:程序实现

        本文链接:https://www.haomeiwen.com/subject/jydnjctx.html