美文网首页dubbo
基于Netty的简单RPC案例

基于Netty的简单RPC案例

作者: 晴天哥_王志 | 来源:发表于2020-05-17 18:18 被阅读0次

开篇

  • 这篇文章的目的是手写一个简单的RPC框架:一是加深对Dubbo RPC框架的理解,二是实践Netty的常见用法。
  • 整个代码参考了github上的开源代码,然后在这个基础上做了小幅修改,通过手写代码加深印象。
  • 这个简单的RPC框架包含了RPC的核心组件,即生产者(provider)、消费者(consumer)、注册中心和通信协议四大部分。

整体框架

通信协议

/**
 * Request message sent by the consumer to a provider.
 *
 * <p>RpcProxy fills every field from the intercepted interface call; RpcHandler uses them to look
 * up and invoke the target service. The message is serialized with Protostuff (see
 * SerializationUtil). NOTE(review): Protostuff's RuntimeSchema assigns field numbers by declaration
 * order, so reordering these fields would change the wire format — keep consumer and provider on
 * the same version of this class. The accessors used elsewhere (setRequestId, getClassName, ...)
 * are not shown in this snippet; presumably generated (e.g. by Lombok) — confirm.
 */
public class RpcRequest implements Serializable {

    private static final long serialVersionUID = 5623523207288226903L;

    /**
     * Request ID. RpcProxy sets a random UUID; RpcHandler copies it into the response ID.
     */
    private String requestId;
    /**
     * Fully qualified name of the class declaring the invoked method (the service interface);
     * RpcHandler uses it as the key into its handlerMap.
     */
    private String className;
    /**
     * Name of the method to invoke.
     */
    private String methodName;
    /**
     * Declared parameter types of the method; used together with methodName to select the
     * method (including the right overload) on the provider.
     */
    private Class<?>[] parameterTypes;
    /**
     * Argument values of the call, positionally matching parameterTypes.
     */
    private Object[] parameters;
}
  • RpcRequest代表请求报文,核心字段如上注释所示。
/**
 * Response message sent by the provider back to the consumer.
 *
 * <p>RpcHandler fills it with either the method's result or the exception it hit; RpcProxy then
 * returns the result or rethrows the error. Serialized with Protostuff (see SerializationUtil);
 * NOTE(review): presumably field order is part of the wire format, as for RpcRequest. RpcProxy
 * calls hasError() and the accessors are not shown in this snippet — presumably generated, with
 * hasError() meaning error != null; confirm.
 */
public class RpcResponse implements Serializable {

    private static final long serialVersionUID = -2733031642960312811L;

    /**
     * Response ID: RpcHandler copies the request's requestId into it. (The field name is
     * misspelled, but callers use setReponseId, so it must stay as is.)
     */
    private String reponseId;

    /**
     * Exception raised while invoking the service on the provider; set instead of result.
     */
    private Throwable error;

    /**
     * Return value of the invoked service method.
     */
    private Object result;
}
  • RpcResponse代表响应报文,核心字段如上注释所示。

编解码中心

/**
 * Netty outbound encoder that frames one message as a 4-byte big-endian length followed by its
 * Protostuff bytes (the format RpcDecoder reads back).
 *
 * <p>Only instances of {@code genericClass} are encoded. Any other outbound message is passed on
 * to the next handler unchanged instead of being silently dropped.
 */
public class RpcEncoder extends MessageToByteEncoder<Object> {

    /** The only message type this encoder serializes (RpcRequest on the client, RpcResponse on the server). */
    private final Class<?> genericClass;

    /**
     * @param genericClass type of the messages to encode; must not be null
     */
    public RpcEncoder(Class<?> genericClass) {
        this.genericClass = java.util.Objects.requireNonNull(genericClass, "genericClass");
    }

    /**
     * Claims only instances of {@code genericClass}. Without this, every outbound message was
     * accepted, and one of another type was released and replaced by an empty buffer — i.e. lost.
     */
    @Override
    public boolean acceptOutboundMessage(Object msg) throws Exception {
        return genericClass.isInstance(msg);
    }

    /**
     * Writes {@code length:int} then the serialized bytes of {@code o}.
     */
    @Override
    protected void encode(ChannelHandlerContext channelHandlerContext, Object o, ByteBuf byteBuf) throws Exception {
        byte[] data = SerializationUtil.serialize(o);

        byteBuf.writeInt(data.length);
        byteBuf.writeBytes(data);
    }
}
  • RpcEncoder负责消息到字节(MessageToByte)的转换:把RpcRequest/RpcResponse用Protostuff序列化成字节数组,先写入4字节的长度字段,再写入序列化后的数据,接收方据此进行拆包。
/**
 * Netty inbound decoder for the frames written by RpcEncoder: a 4-byte big-endian length followed
 * by that many Protostuff bytes, deserialized into {@code genericClass}.
 *
 * <p>Incomplete frames (half packets) are left in the buffer until more bytes arrive. A negative
 * length, or one above {@code maxFrameLength}, is treated as a corrupt stream: the buffered input
 * is discarded and the connection closed.
 */
public class RpcDecoder extends ByteToMessageDecoder {

    /** Size in bytes of the int length prefix written before each payload. */
    private static final int LENGTH_FIELD_SIZE = 4;

    /** Type every frame is deserialized into (RpcRequest on the server, RpcResponse on the client). */
    private final Class<?> genericClass;

    /** Largest payload length accepted; a larger announced length is treated as corrupt. */
    private final int maxFrameLength;

    /**
     * Creates a decoder with no practical frame-size limit ({@link Integer#MAX_VALUE}).
     *
     * @param genericClass type to deserialize each frame into
     */
    public RpcDecoder(Class<?> genericClass) {
        this(genericClass, Integer.MAX_VALUE);
    }

    /**
     * @param genericClass   type to deserialize each frame into
     * @param maxFrameLength largest payload length to accept, must be {@code >= 0}; bounding it stops
     *                       a peer from making this decoder buffer an arbitrarily large frame
     * @throws IllegalArgumentException if {@code maxFrameLength} is negative
     */
    public RpcDecoder(Class<?> genericClass, int maxFrameLength) {
        if (maxFrameLength < 0) {
            throw new IllegalArgumentException("maxFrameLength must be >= 0: " + maxFrameLength);
        }
        this.genericClass = genericClass;
        this.maxFrameLength = maxFrameLength;
    }

    @Override
    protected void decode(ChannelHandlerContext channelHandlerContext, ByteBuf byteBuf, List<Object> list) throws Exception {

        if (byteBuf.readableBytes() < LENGTH_FIELD_SIZE) {
            // Length prefix not complete yet.
            return;
        }

        byteBuf.markReaderIndex();
        int dataLength = byteBuf.readInt();

        if (dataLength < 0 || dataLength > maxFrameLength) {
            // Corrupt stream. Drop everything buffered: ByteToMessageDecoder keeps calling decode()
            // while input is being consumed, so leftover bytes would otherwise be reinterpreted as
            // further length prefixes.
            byteBuf.skipBytes(byteBuf.readableBytes());
            channelHandlerContext.close();
            return;
        }

        if (byteBuf.readableBytes() < dataLength) {
            // Half packet: rewind to the length prefix and wait for more bytes.
            byteBuf.resetReaderIndex();
            return;
        }

        byte[] data = new byte[dataLength];
        byteBuf.readBytes(data);

        list.add(SerializationUtil.deSerialize(data, genericClass));
    }
}
  • RpcDecoder负责字节到消息(ByteToMessage)的转换:先读取4字节的长度字段,若剩余可读字节不足,则重置读指针、等待更多数据(处理半包);长度为负数时直接关闭连接;数据完整后再反序列化为RpcRequest/RpcResponse。
/**
 * Protostuff-based (de)serialization helpers shared by RpcEncoder and RpcDecoder.
 */
public class SerializationUtil {

    /** Protostuff schemas keyed by class; building a RuntimeSchema is reflective, so reuse them. */
    private static final Map<Class<?>, Schema<?>> SCHEMA_CACHE = new ConcurrentHashMap<>();

    /** Creates the empty target object that Protostuff merges the decoded fields into (caching enabled). */
    private static final Objenesis OBJENESIS = new ObjenesisStd(true);

    private SerializationUtil() {
        // static helpers only
    }

    /**
     * Returns the cached schema for {@code cls}, creating it at most once per class.
     */
    @SuppressWarnings("unchecked") // the cache always stores a Schema<T> under the key Class<T>
    private static <T> Schema<T> getSchema(Class<T> cls) {
        // computeIfAbsent makes lookup-and-insert atomic, unlike a separate get/put.
        return (Schema<T>) SCHEMA_CACHE.computeIfAbsent(cls, key -> RuntimeSchema.createFrom(key));
    }

    /**
     * Serializes {@code obj} using the schema of its runtime class.
     *
     * @param obj object to serialize; must not be null
     * @param <T> type of the object
     * @return the Protostuff-encoded bytes
     * @throws NullPointerException     if {@code obj} is null
     * @throws IllegalArgumentException if serialization fails (the cause is preserved)
     */
    public static <T> byte[] serialize(T obj) {
        java.util.Objects.requireNonNull(obj, "obj");

        @SuppressWarnings("unchecked") // getClass() of a T is a Class<? extends T>
        Class<T> cls = (Class<T>) obj.getClass();
        LinkedBuffer buffer = LinkedBuffer.allocate(LinkedBuffer.DEFAULT_BUFFER_SIZE);

        try {
            return ProtostuffIOUtil.toByteArray(obj, getSchema(cls), buffer);
        } catch (Exception e) {
            throw new IllegalArgumentException(e.getMessage(), e);
        } finally {
            buffer.clear();
        }
    }

    /**
     * Deserializes {@code data} into a new instance of {@code cls}.
     *
     * @param data Protostuff-encoded bytes
     * @param cls  target type
     * @param <T>  target type
     * @return the populated instance
     * @throws IllegalArgumentException if instantiation or decoding fails (the cause is preserved)
     */
    public static <T> T deSerialize(byte[] data, Class<T> cls) {

        try {
            T message = OBJENESIS.newInstance(cls);
            ProtostuffIOUtil.mergeFrom(data, message, getSchema(cls));

            return message;
        } catch (Exception e) {
            throw new IllegalArgumentException(e.getMessage(), e);
        }
    }
}

注册中心

/**
 * Publishes provider addresses to ZooKeeper.
 */
public class ServiceRegistry {

    private static final Logger LOGGER = LoggerFactory.getLogger(ServiceRegistry.class);

    /** ZooKeeper connect string handed to the ZooKeeper client. */
    private final String registryAddress;

    public ServiceRegistry(String registryAddress) {
        this.registryAddress = registryAddress;
    }

    /**
     * Publishes {@code data} (the provider's "ip:port") as an ephemeral sequential node at
     * Constant.ZK_DATA_PATH, creating the persistent root node first if needed. Null data is
     * ignored; failures are logged, not thrown.
     *
     * <p>The ZooKeeper session opened here is intentionally left open: an ephemeral node is removed
     * when the session that created it ends, so closing it would unregister the provider.
     */
    public void register(String data) {
        if (null == data) {
            return;
        }
        ZooKeeper zk = connectServer();
        if (null != zk) {
            createNode(zk, data);
        }
    }

    /**
     * Opens a ZooKeeper session and blocks until it reports SyncConnected.
     *
     * @return the connected client, or null if it could not be created or the wait was interrupted
     */
    private ZooKeeper connectServer() {
        // One latch per connection: a shared latch is already at zero after the first call, so a
        // later register() would not wait for its own session to connect.
        final CountDownLatch connected = new CountDownLatch(1);
        ZooKeeper zk = null;
        try {
            zk = new ZooKeeper(registryAddress, Constant.ZK_SESSION_TIMEOUT, new Watcher() {
                @Override
                public void process(WatchedEvent watchedEvent) {
                    if (watchedEvent.getState() == Event.KeeperState.SyncConnected) {
                        connected.countDown();
                    }
                }
            });

            connected.await();
            return zk;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            LOGGER.error("connectServer 连接zk异常 ", e);
        } catch (Exception e) {
            LOGGER.error("connectServer 连接zk异常 ", e);
        }

        // Failure path: release any half-open session rather than returning an unconnected client.
        closeQuietly(zk);
        return null;
    }

    /**
     * Creates the ephemeral sequential node holding {@code data}; errors are logged.
     */
    private void createNode(ZooKeeper zooKeeper, String data) {
        try {
            ensureRootNode(zooKeeper);

            byte[] dataBytes = data.getBytes(java.nio.charset.StandardCharsets.UTF_8);
            String path = zooKeeper.create(Constant.ZK_DATA_PATH, dataBytes, ZooDefs.Ids.OPEN_ACL_UNSAFE, CreateMode.EPHEMERAL_SEQUENTIAL);

            LOGGER.info("createNode 创建节点 {} 成功", path);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            LOGGER.error("createNode 创建节点异常 ", e);
        } catch (Exception e) {
            LOGGER.error("createNode 创建节点异常 ", e);
        }
    }

    /**
     * Creates the persistent registry root if it is missing, tolerating another provider creating
     * it between the exists() check and create().
     */
    private static void ensureRootNode(ZooKeeper zooKeeper) throws Exception {
        if (null != zooKeeper.exists(Constant.ZK_REGISTRY_ROOT_PATH, false)) {
            return;
        }
        try {
            zooKeeper.create(Constant.ZK_REGISTRY_ROOT_PATH, "".getBytes(), ZooDefs.Ids.OPEN_ACL_UNSAFE, CreateMode.PERSISTENT);
        } catch (InterruptedException e) {
            throw e;
        } catch (Exception e) {
            // Lost the race with a concurrent creator: fine as long as the root exists now.
            if (null == zooKeeper.exists(Constant.ZK_REGISTRY_ROOT_PATH, false)) {
                throw e;
            }
        }
    }

    /** Closes {@code zk} if non-null, restoring the interrupt flag if the close is interrupted. */
    private static void closeQuietly(ZooKeeper zk) {
        if (null == zk) {
            return;
        }
        try {
            zk.close();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }
}
  • 注册中心负责把provider的ip:port注册到zk上:节点以临时顺序节点(EPHEMERAL_SEQUENTIAL)的方式创建,provider的zk会话结束后节点会自动删除。这里zk节点的命名方式比较随意,没有参考意义。
/**
 * Discovers provider addresses from ZooKeeper and keeps them up to date via watches.
 */
public class ServiceDiscovery {
    private static final Logger LOGGER = LoggerFactory.getLogger(ServiceDiscovery.class);

    /**
     * Current provider addresses ("ip:port"). Replaced wholesale by watchNode (which also runs from
     * ZooKeeper watch callbacks) and never mutated in place; volatile so discovery() callers see
     * the latest published list.
     */
    private volatile List<String> dataList = new ArrayList<>();

    /**
     * ZooKeeper connect string.
     */
    private final String registryAddress;

    /**
     * Connects to ZooKeeper and loads the initial provider list (errors are logged).
     */
    public ServiceDiscovery(String registryAddress) {
        this.registryAddress = registryAddress;

        ZooKeeper zk = connectServer();
        if (null != zk) {
            watchNode(zk);
        }
    }

    /**
     * Picks a provider at random.
     *
     * @return a provider address, or null if none is currently known
     */
    public String discovery() {
        // Read the volatile field once: a refresh between size() and get() could otherwise swap in
        // a shorter list and make get() throw IndexOutOfBoundsException.
        List<String> snapshot = dataList;
        int size = snapshot.size();
        if (size == 0) {
            return null;
        }
        return snapshot.get(ThreadLocalRandom.current().nextInt(size));
    }

    /**
     * Opens a ZooKeeper session and blocks until it reports SyncConnected.
     *
     * @return the connected client, or null if it could not be created or the wait was interrupted
     */
    private ZooKeeper connectServer() {
        final CountDownLatch connected = new CountDownLatch(1);
        ZooKeeper zk = null;

        try {
            zk = new ZooKeeper(this.registryAddress, Constant.ZK_SESSION_TIMEOUT, new Watcher() {
                @Override
                public void process(WatchedEvent watchedEvent) {
                    if (watchedEvent.getState() == Event.KeeperState.SyncConnected) {
                        connected.countDown();
                    }
                }
            });
            connected.await();
            return zk;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            LOGGER.error("connectServer 连接zk服务报错", e);
        } catch (Exception e) {
            LOGGER.error("connectServer 连接zk服务报错", e);
        }

        // Failure path: release any half-open session rather than returning an unconnected client.
        closeQuietly(zk);
        return null;
    }

    /**
     * Re-reads all provider addresses under the registry root and re-arms a one-shot watch so the
     * next registry change triggers another refresh.
     */
    private void watchNode(final ZooKeeper zk) {
        Watcher refresh = new Watcher() {
            @Override
            public void process(WatchedEvent watchedEvent) {
                // Any node event (children changed, root created/deleted) means "re-read";
                // EventType.None events are connection-state notifications.
                if (watchedEvent.getType() != Event.EventType.None) {
                    watchNode(zk);
                }
            }
        };

        try {
            if (null == zk.exists(Constant.ZK_REGISTRY_ROOT_PATH, false)
                    && null == zk.exists(Constant.ZK_REGISTRY_ROOT_PATH, refresh)) {
                // No provider has registered yet. getChildren() would throw without setting a watch,
                // so watch for the root's creation instead of never learning about providers.
                this.dataList = new ArrayList<>();
                return;
            }

            List<String> nodeList = zk.getChildren(Constant.ZK_REGISTRY_ROOT_PATH, refresh);

            List<String> addresses = new ArrayList<>(nodeList.size());
            for (String node : nodeList) {
                byte[] bytes = zk.getData(Constant.ZK_REGISTRY_ROOT_PATH + "/" + node, false, null);
                if (null != bytes) {
                    addresses.add(new String(bytes, java.nio.charset.StandardCharsets.UTF_8));
                }
            }

            this.dataList = addresses;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            LOGGER.error("watchNode 监视zk节点异常", e);
        } catch (Exception e) {
            LOGGER.error("watchNode 监视zk节点异常", e);
        }
    }

    /** Closes {@code zk} if non-null, restoring the interrupt flag if the close is interrupted. */
    private static void closeQuietly(ZooKeeper zk) {
        if (null == zk) {
            return;
        }
        try {
            zk.close();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }
}
  • ServiceDiscovery在构造时连接zk,并通过watchNode读取所有provider的地址;子节点发生变化时,Watcher会再次调用watchNode重新拉取。
  • ServiceDiscovery的dataList保存的是服务提供者ip:port信息。
  • ServiceDiscovery通过discovery对外提供provider地址的查询:从dataList中随机选择一个(简单的随机负载均衡),列表为空时返回null。

生产者

/**
 * Spring-managed RPC provider: collects the {@code @RpcService} beans, serves them over Netty and
 * registers its address with the registry.
 */
public class RpcServer implements ApplicationContextAware, InitializingBean {

    private static final Logger LOGGER = LoggerFactory.getLogger(RpcServer.class);

    /** Address to bind and to publish, in "host:port" form. */
    private final String serverAddress;

    /** Registry used to publish serverAddress; when null the server runs unregistered. */
    private final ServiceRegistry serviceRegistry;

    /**
     * Service interface name -> service bean. Filled in setApplicationContext and handed to every
     * RpcHandler, which only reads it.
     */
    private final Map<String, Object> handlerMap = new HashMap<>();

    public RpcServer(String serverAddress, ServiceRegistry serviceRegistry) {
        this.serverAddress = serverAddress;
        this.serviceRegistry = serviceRegistry;
    }

    /**
     * Binds the Netty server to serverAddress, registers the address, then blocks until the server
     * channel closes; both event loop groups are shut down afterwards.
     *
     * <p>Because this blocks, the thread initializing the Spring context stays in this method for
     * as long as the server runs.
     *
     * @throws IllegalArgumentException if serverAddress is not "host:port"
     * @throws Exception                if binding fails or the wait is interrupted (logged, then rethrown)
     */
    @Override
    public void afterPropertiesSet() throws Exception {
        // Parse first so a malformed address fails before any event loop threads are started.
        String[] array = serverAddress.split(":");
        if (array.length != 2) {
            throw new IllegalArgumentException("serverAddress must be host:port, got " + serverAddress);
        }
        String host = array[0];
        int port = Integer.parseInt(array[1]);

        // The boss group only accepts connections for one server channel, so one thread suffices.
        EventLoopGroup masterGroup = new NioEventLoopGroup(1);
        EventLoopGroup workerGroup = new NioEventLoopGroup();

        try {
            ServerBootstrap bootstrap = new ServerBootstrap();
            bootstrap.group(masterGroup, workerGroup)
                    .channel(NioServerSocketChannel.class)
                    .childHandler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        protected void initChannel(SocketChannel socketChannel) throws Exception {
                            socketChannel.pipeline()
                                    .addLast(new RpcDecoder(RpcRequest.class))
                                    .addLast(new RpcEncoder(RpcResponse.class))
                                    .addLast(new RpcHandler(handlerMap));
                        }
                    })
                    .option(ChannelOption.SO_BACKLOG, 1024)
                    .childOption(ChannelOption.SO_KEEPALIVE, true);

            // Start the RPC server.
            ChannelFuture channelFuture = bootstrap.bind(host, port).sync();
            LOGGER.debug("server started on port: {}", port);

            if (null != serviceRegistry) {
                // Publish only after a successful bind, so consumers never see a dead address.
                serviceRegistry.register(serverAddress);
                LOGGER.debug("register service:{}", serverAddress);
            }

            // Block until the server channel is closed.
            channelFuture.channel().closeFuture().sync();
        } catch (Exception e) {
            // Never swallow: a failed bind must fail initialization rather than leave the
            // application running without a server.
            LOGGER.error("rpc server on {} failed", serverAddress, e);
            throw e;
        } finally {
            workerGroup.shutdownGracefully();
            masterGroup.shutdownGracefully();
        }
    }

    /**
     * Maps each {@code @RpcService} bean under the name of the interface given in its annotation.
     */
    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {

        Map<String, Object> serviceBeanMap = applicationContext.getBeansWithAnnotation(RpcService.class);
        if (MapUtils.isNotEmpty(serviceBeanMap)) {
            for (Map.Entry<String, Object> entry : serviceBeanMap.entrySet()) {
                // Look the annotation up through the bean factory: getClass().getAnnotation() returns
                // null when the bean is a JDK/CGLIB proxy (e.g. because of AOP).
                RpcService rpcService = applicationContext.findAnnotationOnBean(entry.getKey(), RpcService.class);
                String interfaceName = rpcService.value().getName();
                handlerMap.put(interfaceName, entry.getValue());
            }
        }
    }
}
  • handlerMap保存的是Interface => serviceBean的映射关系。
  • RpcServer通过Netty来创建Server对象实现服务监听,通过绑定RpcDecoder、RpcEncoder、RpcHandler来实现消息的编解码和实际操作。
/**
 * Provider-side handler: invokes the requested service method and answers with one RpcResponse
 * per connection.
 */
public class RpcHandler extends SimpleChannelInboundHandler<RpcRequest> {

    private static final Logger LOGGER = LoggerFactory.getLogger(RpcHandler.class);

    /** Service interface name -> service bean; only read here. */
    private final Map<String, Object> handlerMap;

    public RpcHandler(Map<String, Object> handlerMap) {
        this.handlerMap = handlerMap;
    }

    /**
     * Executes the decoded request, writes back the result or the failure, then closes the
     * connection once the response has been written.
     */
    @Override
    protected void channelRead0(ChannelHandlerContext channelHandlerContext, RpcRequest rpcRequest) throws Exception {

        RpcResponse rpcResponse = new RpcResponse();
        rpcResponse.setReponseId(rpcRequest.getRequestId());

        try {
            rpcResponse.setResult(this.handle(rpcRequest));
        } catch (InvocationTargetException e) {
            // FastMethod wraps whatever the service threw; send the real exception so the consumer
            // sees the service's own exception type rather than the reflective wrapper.
            Throwable target = e.getTargetException();
            rpcResponse.setError(target != null ? target : e);
        } catch (Exception e) {
            rpcResponse.setError(e);
        }

        channelHandlerContext.writeAndFlush(rpcResponse).addListener(ChannelFutureListener.CLOSE);
    }

    /**
     * Logs pipeline failures (e.g. undecodable input) and closes the connection.
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        LOGGER.error("RpcHandler failed on channel {}", ctx.channel(), cause);
        ctx.close();
    }

    /**
     * Looks up the bean registered for the request's class name and invokes the method selected
     * by name and parameter types via cglib FastClass.
     *
     * @throws IllegalStateException      if no service is registered for the requested class
     * @throws InvocationTargetException  if the service method itself throws
     */
    private Object handle(RpcRequest request) throws InvocationTargetException {
        String className = request.getClassName();
        Object serviceBean = this.handlerMap.get(className);
        if (serviceBean == null) {
            throw new IllegalStateException("no service registered for " + className);
        }

        String methodName = request.getMethodName();
        Class<?>[] parameterTypes = request.getParameterTypes();
        Object[] parameters = request.getParameters();

        FastClass serviceFastClass = FastClass.create(serviceBean.getClass());
        FastMethod serviceFastMethod = serviceFastClass.getMethod(methodName, parameterTypes);

        return serviceFastMethod.invoke(serviceBean, parameters);
    }
}
  • RpcHandler#channelRead0接收的是已由RpcDecoder解码好的RpcRequest,它调用handle执行服务,把结果或异常封装成RpcResponse写回,写完后关闭连接。
  • RpcHandler#handle根据RpcRequest中的类名从handlerMap中找到对应的serviceBean,再通过cglib的FastClass/FastMethod调用目标方法(FastClass通过生成的字节码按方法索引直接调用,而不是走JDK反射)。

消费者

/**
 * One-shot RPC client: opens a connection, sends one request and blocks until the response (or
 * the end of the connection) arrives. It is also the inbound handler that receives the response.
 */
public class RpcClient extends SimpleChannelInboundHandler<RpcResponse> {

    private static final Logger LOGGER = LoggerFactory.getLogger(RpcClient.class);

    /**
     * Host name of the provider.
     */
    private final String host;
    /**
     * Port of the provider.
     */
    private final int port;
    /**
     * Response to the in-flight request; null until channelRead0 fires. Guarded by {@code obj}.
     */
    private RpcResponse response;

    /** Monitor used to hand the response from the Netty I/O thread to the thread blocked in send(). */
    private final Object obj = new Object();

    public RpcClient(String host, int port) {
        this.host = host;
        this.port = port;
    }

    /**
     * Stores the response and wakes the thread waiting in send().
     */
    @Override
    protected void channelRead0(ChannelHandlerContext channelHandlerContext, RpcResponse rpcResponse) throws Exception {
        synchronized (obj) {
            this.response = rpcResponse;
            obj.notifyAll();
        }
    }

    /**
     * Wakes send() when the connection ends, so it does not wait forever for a response that can
     * no longer arrive.
     */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        synchronized (obj) {
            obj.notifyAll();
        }
        super.channelInactive(ctx);
    }

    /**
     * Logs the failure and closes the connection; the close in turn wakes send().
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        LOGGER.error("RpcClient 异常 ", cause);
        ctx.close();
    }

    /**
     * Sends {@code rpcRequest} over a new connection and waits for its response.
     *
     * @return the response, or null if the connection ended without one or an error occurred (logged)
     */
    public RpcResponse send(RpcRequest rpcRequest) throws Exception {
        EventLoopGroup group = new NioEventLoopGroup();

        try {
            Bootstrap bootstrap = new Bootstrap();
            bootstrap.group(group)
                    .channel(NioSocketChannel.class)
                    .handler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        protected void initChannel(SocketChannel socketChannel) throws Exception {
                            socketChannel.pipeline()
                                    .addLast(new RpcEncoder(RpcRequest.class))
                                    .addLast(new RpcDecoder(RpcResponse.class))
                                    .addLast(RpcClient.this);
                        }
                    })
                    .option(ChannelOption.SO_KEEPALIVE, true);

            ChannelFuture future = bootstrap.connect(host, port).sync();

            future.channel().writeAndFlush(rpcRequest).sync();

            RpcResponse received;
            synchronized (obj) {
                // Guarded wait: the response may already have arrived before this thread got here,
                // and wait() may wake spuriously, so always re-check the condition.
                while (response == null && future.channel().isActive()) {
                    obj.wait();
                }
                received = response;
            }

            if (null != received) {
                // Wait for the provider to close the connection after answering.
                future.channel().closeFuture().sync();
            }

            return received;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            LOGGER.error("RpcClient 异常 ", e);
        } catch (Exception e) {
            LOGGER.error("RpcClient 异常 ", e);
        } finally {
            group.shutdownGracefully();
        }

        return null;
    }
}
  • channelRead0负责接收响应报文:保存RpcResponse,并通过obj.notifyAll()唤醒等待中的send线程。
  • send内部为每次调用创建一个Netty客户端:连接provider、发送请求,然后在obj上wait直到收到响应,再等待连接关闭并释放EventLoopGroup。
/**
 * Consumer-side factory for JDK dynamic proxies that turn interface calls into RPC requests.
 */
public class RpcProxy {

    private static final Logger LOGGER = LoggerFactory.getLogger(RpcProxy.class);

    /**
     * Fixed provider address ("host:port"), used only when no ServiceDiscovery is configured.
     * NOTE(review): no constructor visible here assigns it — confirm how it is meant to be set.
     */
    private String serverAddress;

    /** Registry lookup used to pick a provider for every call; may be null. */
    private ServiceDiscovery serviceDiscovery;

    public RpcProxy(ServiceDiscovery serviceDiscovery) {
        this.serviceDiscovery = serviceDiscovery;
    }

    /**
     * Creates a proxy implementing {@code interfaceClass}. Each call is sent as an RpcRequest to a
     * provider chosen by the ServiceDiscovery (or to serverAddress), and the remote result is
     * returned or the remote error rethrown. equals/hashCode/toString are answered locally.
     *
     * @param interfaceClass service interface to implement
     * @param <T>            type the caller assigns the proxy to
     * @return the proxy
     */
    @SuppressWarnings("unchecked") // the proxy implements interfaceClass; T is chosen by the caller
    public <T> T create(final Class<?> interfaceClass) {

        return (T) Proxy.newProxyInstance(
                interfaceClass.getClassLoader(),
                new Class<?>[]{interfaceClass},
                new InvocationHandler() {
                    @Override
                    public Object invoke(Object o, Method method, Object[] objects) throws Throwable {

                        // java.lang.Object methods are routed through the proxy too; no provider
                        // exports java.lang.Object, so answer them here instead of going remote.
                        if (method.getDeclaringClass() == Object.class) {
                            return invokeObjectMethod(o, method, objects, interfaceClass);
                        }

                        RpcRequest request = new RpcRequest();
                        request.setRequestId(UUID.randomUUID().toString());
                        request.setClassName(method.getDeclaringClass().getName());
                        request.setMethodName(method.getName());
                        request.setParameterTypes(method.getParameterTypes());
                        request.setParameters(objects);

                        // Resolve the provider into a local: writing it to the shared field would let
                        // concurrent calls overwrite each other's target address.
                        String address = (null != serviceDiscovery) ? serviceDiscovery.discovery() : serverAddress;

                        if (address == null) {
                            throw new RuntimeException("serverAddress is null...");
                        }

                        String[] array = address.split(":");
                        String host = array[0];
                        int port = Integer.parseInt(array[1]);

                        RpcClient client = new RpcClient(host, port);

                        long startTime = System.currentTimeMillis();
                        // Send the request through the RPC client and wait for the response.
                        RpcResponse response = client.send(request);
                        LOGGER.debug("send rpc request elapsed time: {}ms...", System.currentTimeMillis() - startTime);

                        if (response == null) {
                            throw new RuntimeException("response is null...");
                        }

                        // Return the remote result, or rethrow the remote error.
                        if (response.hasError()) {
                            throw response.getError();
                        } else {
                            return response.getResult();
                        }
                    }
                });
    }

    /**
     * Local implementations of the java.lang.Object methods a JDK proxy forwards: identity-based
     * equals/hashCode and a descriptive toString.
     */
    private static Object invokeObjectMethod(Object proxy, Method method, Object[] args, Class<?> interfaceClass) {
        switch (method.getName()) {
            case "equals":
                return proxy == args[0];
            case "hashCode":
                return System.identityHashCode(proxy);
            case "toString":
                return interfaceClass.getName() + "$RpcProxy@" + Integer.toHexString(System.identityHashCode(proxy));
            default:
                throw new UnsupportedOperationException(method.toString());
        }
    }
}
  • RpcProxy内部通过JDK动态代理(Proxy.newProxyInstance)而不是cglib来实现consumer端的代理功能:每次方法调用都会被封装成RpcRequest,通过服务发现选出provider后,由RpcClient发送。

调用代码

public class RpcProvider {

    /** Classpath location of the Spring XML that declares the provider-side beans. */
    private static final String PROVIDER_CONFIG = "provider/application-provider.xml";

    /**
     * Starts the provider by loading its Spring context; the context is never closed because it
     * must stay alive for as long as the provider serves requests.
     */
    public static void main(String[] args) {
        new ClassPathXmlApplicationContext(PROVIDER_CONFIG);
    }
}



public class RpcConsumer {

    /**
     * Loads the consumer context, calls HelloService#sayHello on a remote provider through an
     * RpcProxy and prints the reply.
     */
    public static void main(String[] args) {
        ClassPathXmlApplicationContext consumerContext =
                new ClassPathXmlApplicationContext("consumer/application-consumer.xml");

        RpcProxy rpcProxy = consumerContext.getBean(RpcProxy.class);
        HelloService remoteHello = rpcProxy.create(HelloService.class);

        String reply = remoteHello.sayHello("sun boy");
        System.out.println(reply);
    }
}

github源码

相关文章

网友评论

    本文标题:基于Netty的简单RPC案例

    本文链接:https://www.haomeiwen.com/subject/wwmsohtx.html