美文网首页dubbo
基于Netty的简单RPC案例

基于Netty的简单RPC案例

作者: 晴天哥_王志 | 来源:发表于2020-05-17 18:18 被阅读0次

    开篇

    • 这篇文章的目的是想手动实现下RPC框架的思想,一是加深Dubbo的RPC框架的理解,二是实践下Netty的通用实现。
    • 整个代码参考了github上的开源代码,然后在这个基础上做了小幅修改,通过手写代码加深印象。
    • 在这个简单的RPC框架中包含了RPC核心的组件,包括生产者、消费者、注册中心、通信协议的四大部分。

    整体框架

    通信协议

    /**
     * Request message sent by the consumer (RpcProxy/RpcClient) and executed by the provider (RpcHandler).
     *
     * <p>NOTE(review): the message is serialized with a Protostuff RuntimeSchema; field declaration order
     * may determine the wire layout, so keep the order of the fields below stable.
     */
    public class RpcRequest implements Serializable {

        private static final long serialVersionUID = 5623523207288226903L;

        /** Request ID; the provider copies it into the response ID of the reply. */
        private String requestId;
        /** Fully-qualified name of the service interface to invoke. */
        private String className;
        /** Name of the method to invoke. */
        private String methodName;
        /** Declared parameter types of the method, used to pick the right overload. */
        private Class<?>[] parameterTypes;
        /** Argument values, in the same order as {@link #parameterTypes}. */
        private Object[] parameters;

        public String getRequestId() {
            return requestId;
        }

        public void setRequestId(String requestId) {
            this.requestId = requestId;
        }

        public String getClassName() {
            return className;
        }

        public void setClassName(String className) {
            this.className = className;
        }

        public String getMethodName() {
            return methodName;
        }

        public void setMethodName(String methodName) {
            this.methodName = methodName;
        }

        public Class<?>[] getParameterTypes() {
            return parameterTypes;
        }

        public void setParameterTypes(Class<?>[] parameterTypes) {
            this.parameterTypes = parameterTypes;
        }

        public Object[] getParameters() {
            return parameters;
        }

        public void setParameters(Object[] parameters) {
            this.parameters = parameters;
        }
    }
    
    • RpcRequest代表请求报文,核心字段如上注释所示。
    /**
     * Response message produced by the provider (RpcHandler) for one RpcRequest.
     *
     * <p>Exactly one of {@code result} / {@code error} is meaningful: when {@link #hasError()} is true the
     * invocation failed and {@code error} holds the cause.
     *
     * <p>NOTE(review): serialized with a Protostuff RuntimeSchema; keep the field order stable.
     */
    public class RpcResponse implements Serializable {

        private static final long serialVersionUID = -2733031642960312811L;

        /**
         * Response ID, copied from the request ID. The historical spelling "reponseId" (and its accessors)
         * is kept because callers use {@code setReponseId}.
         */
        private String reponseId;

        /** Failure raised while executing the request, or null on success. */
        private Throwable error;

        /** Return value of the invoked method (may legitimately be null). */
        private Object result;

        public String getReponseId() {
            return reponseId;
        }

        public void setReponseId(String reponseId) {
            this.reponseId = reponseId;
        }

        public Throwable getError() {
            return error;
        }

        public void setError(Throwable error) {
            this.error = error;
        }

        public Object getResult() {
            return result;
        }

        public void setResult(Object result) {
            this.result = result;
        }

        /** @return true when the provider reported a failure in {@link #getError()} */
        public boolean hasError() {
            return error != null;
        }
    }
    
    • RpcResponse代表响应报文,核心字段如上注释所示。

    编解码中心

    /**
     * Netty outbound encoder that turns one message into a length-prefixed frame:
     * a 4-byte int length followed by the SerializationUtil bytes. RpcDecoder reads the same layout.
     *
     * <p>Only messages of {@code genericClass} are handled. Any other message is passed on unchanged
     * instead of being silently turned into an empty write, so a misconfigured pipeline fails loudly.
     */
    public class RpcEncoder extends MessageToByteEncoder<Object> {

        /** Type of message this encoder serializes (RpcRequest on the client, RpcResponse on the server). */
        private final Class<?> genericClass;

        /**
         * @param genericClass message type to encode; must not be null
         * @throws NullPointerException if genericClass is null
         */
        public RpcEncoder(Class<?> genericClass) {
            this.genericClass = java.util.Objects.requireNonNull(genericClass, "genericClass");
        }

        /** Claims only messages of the configured type; others continue down the pipeline untouched. */
        @Override
        public boolean acceptOutboundMessage(Object msg) throws Exception {
            return genericClass.isInstance(msg);
        }

        @Override
        protected void encode(ChannelHandlerContext channelHandlerContext, Object o, ByteBuf byteBuf) throws Exception {
            byte[] data = SerializationUtil.serialize(o);

            // Length prefix lets the decoder know how many bytes belong to this frame.
            byteBuf.writeInt(data.length);
            byteBuf.writeBytes(data);
        }
    }
    
    • RpcEncoder负责实现MessageToByte的转换,即把RpcRequest/RpcResponse序列化为byte:先写入4字节的长度字段,再写入序列化后的数据,RpcDecoder据此按长度拆包。
    /**
     * Netty inbound decoder for frames written by RpcEncoder: a 4-byte int length followed by that many
     * serialized bytes, which are deserialized into {@code genericClass}.
     *
     * <p>Partial frames are left in the buffer (reader index reset) until more bytes arrive. A negative
     * length is treated as a corrupt stream: the buffered bytes are discarded and the channel is closed.
     */
    public class RpcDecoder extends ByteToMessageDecoder {

        /** Size of the int length prefix written by RpcEncoder. */
        private static final int LENGTH_FIELD_SIZE = 4;

        /** Target type of each decoded frame (RpcRequest on the server, RpcResponse on the client). */
        private final Class<?> genericClass;

        public RpcDecoder(Class<?> genericClass) {
            this.genericClass = genericClass;
        }

        @Override
        protected void decode(ChannelHandlerContext channelHandlerContext, ByteBuf byteBuf, List<Object> list) throws Exception {

            if (byteBuf.readableBytes() < LENGTH_FIELD_SIZE) {
                return;
            }

            byteBuf.markReaderIndex();
            int dataLength = byteBuf.readInt();

            if (dataLength < 0) {
                // Corrupt stream. Drop everything buffered first: the decode loop keeps calling decode()
                // while bytes remain, and would otherwise keep parsing the garbage as further frames.
                byteBuf.skipBytes(byteBuf.readableBytes());
                channelHandlerContext.close();
                return;
            }

            if (byteBuf.readableBytes() < dataLength) {
                // Incomplete frame: rewind to before the length so it is re-read once more data arrives.
                byteBuf.resetReaderIndex();
                return;
            }

            byte[] data = new byte[dataLength];
            byteBuf.readBytes(data);

            list.add(SerializationUtil.deSerialize(data, genericClass));
        }
    }
    
    • RpcDecoder负责实现ByteToMessage的转换:先读取4字节长度字段,若数据还不完整就重置读指针、等待后续数据(处理半包),数据完整后再反序列化为RpcRequest/RpcResponse。
    /**
     * Static helpers that (de)serialize arbitrary objects with Protostuff runtime schemas.
     *
     * <p>Schemas are built once per class and cached; deserialization instantiates the target class via
     * Objenesis, so it does not require a no-arg constructor.
     */
    public class SerializationUtil {

        // Per-class schema cache; building a RuntimeSchema inspects the class, so do it once per type.
        private static final java.util.concurrent.ConcurrentMap<Class<?>, Schema<?>> CACHED_SCHEMA =
                new ConcurrentHashMap<>();
        // Creates instances without calling constructors; the 'true' flag enables Objenesis' instantiator cache.
        private static final Objenesis OBJENESIS = new ObjenesisStd(true);

        private SerializationUtil() {
            // static utility class, not instantiable
        }

        /** Returns the cached schema for {@code cls}, creating and caching it on first use. */
        @SuppressWarnings("unchecked") // each value is stored under its own Class key, so the cast is safe
        private static <T> Schema<T> getSchema(Class<T> cls) {
            Schema<T> schema = (Schema<T>) CACHED_SCHEMA.get(cls);
            if (schema == null) {
                Schema<T> created = RuntimeSchema.createFrom(cls);
                // putIfAbsent makes concurrent first callers agree on one cached instance.
                Schema<T> existing = (Schema<T>) CACHED_SCHEMA.putIfAbsent(cls, created);
                schema = (existing != null) ? existing : created;
            }
            return schema;
        }

        /**
         * Serializes an object using the schema of its runtime class.
         *
         * @param obj object to serialize; must not be null
         * @param <T> object type
         * @return the Protostuff-encoded bytes
         * @throws NullPointerException if obj is null
         * @throws IllegalArgumentException wrapping any serialization failure
         */
        public static <T> byte[] serialize(T obj) {
            java.util.Objects.requireNonNull(obj, "obj");

            @SuppressWarnings("unchecked") // getClass() of a T is a Class<? extends T>
            Class<T> cls = (Class<T>) obj.getClass();
            LinkedBuffer buffer = LinkedBuffer.allocate(LinkedBuffer.DEFAULT_BUFFER_SIZE);

            try {
                return ProtostuffIOUtil.toByteArray(obj, getSchema(cls), buffer);
            } catch (Exception e) {
                throw new IllegalArgumentException(e.getMessage(), e);
            } finally {
                // Reset the scratch buffer after use.
                buffer.clear();
            }
        }

        /**
         * Deserializes bytes produced by {@link #serialize(Object)} into a new instance of {@code cls}.
         *
         * @param data serialized bytes
         * @param cls target class
         * @param <T> target type
         * @return the populated instance
         * @throws IllegalArgumentException wrapping any instantiation or parsing failure
         */
        public static <T> T deSerialize(byte[] data, Class<T> cls) {

            try {
                T message = OBJENESIS.newInstance(cls);
                ProtostuffIOUtil.mergeFrom(data, message, getSchema(cls));

                return message;
            } catch (Exception e) {
                throw new IllegalArgumentException(e.getMessage(), e);
            }
        }
    }
    

    注册中心

    public class ServiceRegistry {
    
        private static final Logger LOGGER = LoggerFactory.getLogger(ServiceRegistry.class);
    
        private String registryAddress;
        private CountDownLatch latch = new CountDownLatch(1);
    
        public ServiceRegistry(String registryAddress) {
            this.registryAddress = registryAddress;
        }
    
        public void register(String data){
            if (null != data) {
                ZooKeeper zk = connectServer();
                if (null != zk) {
                    createNode(zk, data);
                }
            }
        }
    
        private ZooKeeper connectServer() {
            ZooKeeper zk = null;
            try {
                zk = new ZooKeeper(registryAddress, Constant.ZK_SESSION_TIMEOUT, new Watcher() {
                    @Override
                    public void process(WatchedEvent watchedEvent) {
                        if (watchedEvent.getState() == Event.KeeperState.SyncConnected) {
                            latch.countDown();
                        }
                    }
                });
    
                latch.await();
            } catch (Exception e) {
                LOGGER.error("connectServer 连接zk异常 ", e);
            }
    
            return zk;
        }
    
        private void createNode(ZooKeeper zooKeeper, String data) {
            try {
                byte[] dataBytes = data.getBytes();
                Stat stat = zooKeeper.exists(Constant.ZK_REGISTRY_ROOT_PATH, false);
                if (null == stat) {
                    zooKeeper.create(Constant.ZK_REGISTRY_ROOT_PATH, "".getBytes(), ZooDefs.Ids.OPEN_ACL_UNSAFE, CreateMode.PERSISTENT);
                }
    
                String path = zooKeeper.create(Constant.ZK_DATA_PATH, dataBytes, ZooDefs.Ids.OPEN_ACL_UNSAFE, CreateMode.EPHEMERAL_SEQUENTIAL);
    
                LOGGER.info("createNode 创建节点 {} 成功", path);
            } catch (Exception e) {
                LOGGER.error("createNode 创建节点异常 ", e);
            }
        }
    }
    
    • 注册中心负责把provider的ip:port注册到对应的zk节点,这里的zk节点的命名方式比较随意,没有参考意义。
    public class ServiceDiscovery {
        private static final Logger LOGGER = LoggerFactory.getLogger(ServiceDiscovery.class);
    
        private CountDownLatch latch = new CountDownLatch(1);
        private volatile List<String> dataList = new ArrayList<>();
    
        /**
         * 注册地址
         */
        private String registryAddress;
    
        public ServiceDiscovery(String registryAddress) {
            this.registryAddress = registryAddress;
    
            ZooKeeper zk = connectServer();
            if (null != zk) {
                watchNode(zk);
            }
        }
    
        public String discovery(){
            String data = null;
            int size = dataList.size();
            if(size > 0){
                int temp = ThreadLocalRandom.current().nextInt(size);
                data = dataList.get(temp);
    
            }
            return data;
        }
    
        private ZooKeeper connectServer() {
            ZooKeeper zk = null;
    
            try {
                zk = new ZooKeeper(this.registryAddress, Constant.ZK_SESSION_TIMEOUT, new Watcher() {
                    @Override
                    public void process(WatchedEvent watchedEvent) {
                        if(watchedEvent.getState() == Event.KeeperState.SyncConnected){
                            latch.countDown();
                        }
                    }
                });
                latch.await();
            } catch (Exception e) {
                LOGGER.error("connectServer 连接zk服务报错", e);
            }
    
            return zk;
        }
    
        private void watchNode(final ZooKeeper zk) {
            try {
                List<String> nodeList = zk.getChildren(Constant.ZK_REGISTRY_ROOT_PATH, new Watcher() {
                    @Override
                    public void process(WatchedEvent watchedEvent) {
                        if(watchedEvent.getType() == Event.EventType.NodeChildrenChanged){
                            watchNode(zk);
                        }
                    }
                });
    
                List<String> dataList = new ArrayList<>();
                byte[] bytes;
                for (String node : nodeList){
                    bytes = zk.getData(Constant.ZK_REGISTRY_ROOT_PATH + "/" + node, false, null);
                    dataList.add(new String(bytes));
                }
    
                this.dataList = dataList;
            } catch (Exception e) {
                LOGGER.error("watchNode 监视zk节点异常", e);
            }
        }
    }
    
    • ServiceDiscovery在启动过程中通过watchNode获取对应的provider的地址,并监听注册根节点的子节点变化,变化时重新调用watchNode刷新地址列表。
    • ServiceDiscovery的dataList保存的是服务提供者ip:port信息。
    • ServiceDiscovery通过discovery对外提供provider信息的查询。

    生产者

    public class RpcServer implements ApplicationContextAware, InitializingBean {
    
        private static final Logger LOGGER = LoggerFactory.getLogger(RpcServer.class);
    
        // 保存server地址
        private String serverAddress;
    
        // 保存注册中心
        private ServiceRegistry serviceRegistry;
    
        // 保存接口和服务对象的映射关系
        private Map<String, Object> handlerMap = new HashMap<>();
    
        public RpcServer(String serverAddress, ServiceRegistry serviceRegistry) {
            this.serverAddress = serverAddress;
            this.serviceRegistry = serviceRegistry;
        }
    
        @Override
        public void afterPropertiesSet() throws Exception {
    
            EventLoopGroup masterGroup = new NioEventLoopGroup();
            EventLoopGroup workerGroup = new NioEventLoopGroup();
    
            try {
                ServerBootstrap bootstrap = new ServerBootstrap();
                bootstrap.group(masterGroup, workerGroup)
                        .channel(NioServerSocketChannel.class)
                        .childHandler(new ChannelInitializer<SocketChannel>() {
                            @Override
                            protected void initChannel(SocketChannel socketChannel) throws Exception {
                                socketChannel.pipeline()
                                        .addLast(new RpcDecoder(RpcRequest.class))
                                        .addLast(new RpcEncoder(RpcResponse.class))
                                        .addLast(new RpcHandler(handlerMap));
                            }
                        })
                        .option(ChannelOption.SO_BACKLOG, 1024)
                        .childOption(ChannelOption.SO_KEEPALIVE, true);
    
                //解析IP地址和端口信息
                String[] array = serverAddress.split(":");
                String host = array[0];
                int port = Integer.parseInt(array[1]);
    
                //启动RPC服务端
                ChannelFuture channelFuture = bootstrap.bind(host, port).sync();
                LOGGER.debug("server started on port: {}", port);
    
                if(null != serviceRegistry){
                    //注册服务地址
                    serviceRegistry.register(serverAddress);
                    LOGGER.debug("register service:{}", serverAddress);
                }
    
                //关闭RPC服务器
                channelFuture.channel().closeFuture().sync();
            } catch (Exception e) {
    
            } finally {
                workerGroup.shutdownGracefully();
                masterGroup.shutdownGracefully();
            }
        }
    
        @Override
        public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
    
            Map<String, Object> serviceBeanMap = applicationContext.getBeansWithAnnotation(RpcService.class);
            if (MapUtils.isNotEmpty(serviceBeanMap)) {
                for (Object serviceBean : serviceBeanMap.values()) {
                    String interfaceName = serviceBean.getClass().getAnnotation(RpcService.class).value().getName();
                    handlerMap.put(interfaceName, serviceBean);
                }
            }
        }
    }
    
    • handlerMap保存的是Interface => serviceBean的映射关系。
    • RpcServer通过Netty来创建Server对象实现服务监听,通过绑定RpcDecoder、RpcEncoder、RpcHandler来实现消息的编解码和实际操作。
    /**
     * Server-side handler: executes one decoded RpcRequest against the matching service bean, writes
     * the RpcResponse back, then closes the connection (one request per connection).
     */
    public class RpcHandler extends SimpleChannelInboundHandler<RpcRequest> {

        private static final Logger LOGGER = LoggerFactory.getLogger(RpcHandler.class);

        // Interface name -> service bean implementing it; only read here.
        private final Map<String, Object> handlerMap;

        public RpcHandler(Map<String, Object> handlerMap) {
            this.handlerMap = handlerMap;
        }

        /**
         * Invokes the requested method and replies. Failures of the invocation are reported to the caller
         * via the response's error field instead of being thrown into the pipeline.
         */
        @Override
        protected void channelRead0(ChannelHandlerContext channelHandlerContext, RpcRequest rpcRequest) throws Exception {

            RpcResponse rpcResponse = new RpcResponse();
            rpcResponse.setReponseId(rpcRequest.getRequestId());

            try {
                rpcResponse.setResult(this.handle(rpcRequest));
            } catch (InvocationTargetException e) {
                // Unwrap: the caller should see the exception the service method threw, not the invocation wrapper.
                rpcResponse.setError(e.getTargetException());
            } catch (Exception e) {
                rpcResponse.setError(e);
            }

            channelHandlerContext.writeAndFlush(rpcResponse).addListener(ChannelFutureListener.CLOSE);
        }

        /** Logs pipeline failures and closes the connection so the client is not left waiting. */
        @Override
        public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
            LOGGER.error("RpcHandler 处理请求异常 ", cause);
            ctx.close();
        }

        /**
         * Looks up the service bean for the request's class name and invokes the requested method on it
         * through cglib's FastClass.
         *
         * @throws IllegalStateException if no bean is registered under the requested class name
         * @throws InvocationTargetException if the invoked method itself throws
         */
        private Object handle(RpcRequest request) throws InvocationTargetException {
            String className = request.getClassName();
            Object serviceBean = this.handlerMap.get(className);
            if (serviceBean == null) {
                throw new IllegalStateException("no RPC service registered for " + className);
            }

            FastClass serviceFastClass = FastClass.create(serviceBean.getClass());
            FastMethod serviceFastMethod = serviceFastClass.getMethod(request.getMethodName(), request.getParameterTypes());

            return serviceFastMethod.invoke(serviceBean, request.getParameters());
        }
    }
    
    • RpcHandler#channelRead0接收经RpcDecoder解码后的RpcRequest,调用handle执行请求,把结果或异常封装成RpcResponse写回,写完后关闭连接。
    • RpcHandler#handle内部根据RpcRequest中的类名从handlerMap找到对应的serviceBean对象,并通过cglib的FastClass/FastMethod#invoke实现服务调用(FastClass基于生成的字节码按方法索引调用,而不是JDK反射)。

    消费者

    /**
     * One-shot RPC client: {@link #send(RpcRequest)} opens a connection, writes one request and blocks
     * until the matching response arrives (delivered by {@link #channelRead0}) or the connection closes.
     *
     * <p>An instance is meant for a single {@code send} call.
     */
    public class RpcClient extends SimpleChannelInboundHandler<RpcResponse> {

        private static final Logger LOGGER = LoggerFactory.getLogger(RpcClient.class);

        /**
         * Provider host.
         */
        private String host;
        /**
         * Provider port.
         */
        private int port;
        /**
         * RPC response; written by the Netty I/O thread and read by the sender, both while holding {@code obj}.
         */
        private RpcResponse response;
        /**
         * Set once the connection is inactive, so the sender stops waiting even without a response; guarded by {@code obj}.
         */
        private boolean channelClosed;

        // Monitor used to hand the response from the I/O thread to the thread blocked in send().
        private final Object obj = new Object();

        public RpcClient(String host, int port) {
            this.host = host;
            this.port = port;
        }

        @Override
        protected void channelRead0(ChannelHandlerContext channelHandlerContext, RpcResponse rpcResponse) throws Exception {
            synchronized (obj) {
                this.response = rpcResponse;
                obj.notifyAll();
            }
        }

        /** Wakes the sender when the connection goes away, e.g. the provider closed it without replying. */
        @Override
        public void channelInactive(ChannelHandlerContext ctx) throws Exception {
            synchronized (obj) {
                channelClosed = true;
                obj.notifyAll();
            }
            super.channelInactive(ctx);
        }

        /** Logs pipeline failures and closes the connection, which in turn releases the waiting sender. */
        @Override
        public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
            LOGGER.error("RpcClient 异常 ", cause);
            ctx.close();
        }

        /**
         * Sends one request and waits for its response.
         *
         * @param rpcRequest request to send
         * @return the provider's response, or null if none arrived or an error occurred (errors are logged)
         */
        public RpcResponse send(RpcRequest rpcRequest) throws Exception {
            EventLoopGroup group = new NioEventLoopGroup();

            try {
                Bootstrap bootstrap = new Bootstrap();
                bootstrap.group(group)
                        .channel(NioSocketChannel.class)
                        .handler(new ChannelInitializer<SocketChannel>() {
                            @Override
                            protected void initChannel(SocketChannel socketChannel) throws Exception {
                                socketChannel.pipeline()
                                        .addLast(new RpcEncoder(RpcRequest.class))
                                        .addLast(new RpcDecoder(RpcResponse.class))
                                        .addLast(RpcClient.this);
                            }
                        })
                        .option(ChannelOption.SO_KEEPALIVE, true);

                ChannelFuture future = bootstrap.connect(host, port).sync();

                future.channel().writeAndFlush(rpcRequest).sync();

                RpcResponse received;
                synchronized (obj) {
                    // Wait on the condition, not a bare wait(): the response may already have been
                    // delivered before we got here (its notify would be lost), and wait() can wake spuriously.
                    while (response == null && !channelClosed) {
                        obj.wait();
                    }
                    received = response;
                }

                // Close the one-shot RPC connection.
                future.channel().close().sync();

                return received;
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                LOGGER.error("RpcClient 异常 ", e);
            } catch (Exception e) {
                LOGGER.error("RpcClient 异常 ", e);
            } finally {
                group.shutdownGracefully();
            }

            return null;
        }
    }
    
    • channelRead0负责读取响应报文。
    • send内部通过Netty创建客户端并连接服务端,发送请求报文后阻塞等待,直到channelRead0收到响应报文将其唤醒,再返回该响应。
    /**
     * Creates JDK dynamic proxies for service interfaces. Each interface method call becomes an
     * RpcRequest sent to a provider chosen via ServiceDiscovery.
     */
    public class RpcProxy {

        private static final Logger LOGGER = LoggerFactory.getLogger(RpcProxy.class);

        // Fixed provider address, used only when no ServiceDiscovery is configured. Never assigned in this
        // class. NOTE(review): confirm whether another constructor/setter is expected to set it.
        private String serverAddress;

        private ServiceDiscovery serviceDiscovery;

        public RpcProxy(ServiceDiscovery serviceDiscovery) {
            this.serviceDiscovery = serviceDiscovery;
        }

        /**
         * Returns a proxy implementing {@code interfaceClass}.
         *
         * <p>{@code equals}/{@code hashCode}/{@code toString} are answered locally (identity semantics), so
         * logging the proxy or putting it in a collection does not trigger a remote call.
         *
         * @param interfaceClass service interface to proxy
         * @param <T> expected proxy type (unchecked; must match interfaceClass)
         * @return the proxy
         */
        @SuppressWarnings("unchecked")
        public <T> T create(final Class<?> interfaceClass) {

            return (T) Proxy.newProxyInstance(
                    interfaceClass.getClassLoader(),
                    new Class<?>[]{interfaceClass},
                    new InvocationHandler() {
                        @Override
                        public Object invoke(Object o, Method method, Object[] objects) throws Throwable {

                            // Object methods are dispatched here too; keep them local.
                            if (Object.class == method.getDeclaringClass()) {
                                return invokeObjectMethod(o, method, objects, interfaceClass);
                            }

                            RpcRequest request = new RpcRequest();
                            request.setRequestId(UUID.randomUUID().toString());
                            request.setClassName(method.getDeclaringClass().getName());
                            request.setMethodName(method.getName());
                            request.setParameterTypes(method.getParameterTypes());
                            request.setParameters(objects);

                            // Resolve the address into a local: the proxy is shared, and writing the result
                            // into a field lets concurrent calls overwrite each other's choice.
                            String address = (null != serviceDiscovery) ? serviceDiscovery.discovery() : serverAddress;

                            if (address == null) {
                                throw new RuntimeException("serverAddress is null...");
                            }

                            String[] array = address.split(":");
                            if (array.length != 2) {
                                throw new IllegalStateException("invalid serverAddress, expected host:port: " + address);
                            }
                            String host = array[0];
                            int port = Integer.parseInt(array[1]);

                            RpcClient client = new RpcClient(host, port);

                            long startTime = System.currentTimeMillis();
                            // Send the request through the RPC client and wait for its response.
                            RpcResponse response = client.send(request);
                            LOGGER.debug("send rpc request elapsed time: {}ms...", System.currentTimeMillis() - startTime);

                            if (response == null) {
                                throw new RuntimeException("response is null...");
                            }

                            // Rethrow the remote failure, or return the remote result.
                            if (response.hasError()) {
                                throw response.getError();
                            }
                            return response.getResult();
                        }
                    });
        }

        /** Local identity-based implementations of equals/hashCode/toString for a proxy instance. */
        private static Object invokeObjectMethod(Object proxy, Method method, Object[] args, Class<?> interfaceClass) {
            String name = method.getName();
            if ("equals".equals(name)) {
                return proxy == args[0];
            }
            if ("hashCode".equals(name)) {
                return System.identityHashCode(proxy);
            }
            if ("toString".equals(name)) {
                return "RpcProxy(" + interfaceClass.getName() + ")@" + Integer.toHexString(System.identityHashCode(proxy));
            }
            throw new UnsupportedOperationException(name);
        }
    }
    
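    • RpcProxy内部通过JDK动态代理(java.lang.reflect.Proxy)而非cglib来实现consumer的代理功能:每次接口方法调用都被封装成RpcRequest,经ServiceDiscovery选出provider地址后交给RpcClient发送。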

    调用代码

    /**
     * Provider entry point: builds the Spring context described by the provider XML.
     *
     * <p>NOTE(review): presumably that XML declares the RpcServer bean, whose initialization binds the
     * port and blocks for the server's lifetime — confirm against the XML.
     */
    public class RpcProvider {

        /** Classpath location of the provider-side Spring configuration. */
        private static final String PROVIDER_CONFIG = "provider/application-provider.xml";

        public static void main(String[] args) {
            // Constructing the context loads and refreshes it; no reference is kept.
            new ClassPathXmlApplicationContext(PROVIDER_CONFIG);
        }
    }
    
    
    
    /**
     * Consumer entry point: obtains an RpcProxy from the consumer Spring context, performs one remote
     * {@code sayHello} call and prints the result.
     */
    public class RpcConsumer {

        public static void main(String[] args) {
            // try-with-resources closes the Spring context when the demo is done instead of leaking it.
            try (ClassPathXmlApplicationContext context = new ClassPathXmlApplicationContext("consumer/application-consumer.xml")) {

                HelloService helloService = context.getBean(RpcProxy.class).create(HelloService.class);
                String result = helloService.sayHello("sun boy");

                System.out.println(result);
            }
        }
    }
    

    github源码

    相关文章

      网友评论

        本文标题:基于Netty的简单RPC案例

        本文链接:https://www.haomeiwen.com/subject/wwmsohtx.html