开篇
- 这篇文章的目的是手动实现一个简单的RPC框架:一是加深对Dubbo这类RPC框架的理解,二是实践Netty的通用用法。
- 整个代码参考了github上的开源代码,然后在这个基础上做了小幅修改,通过手写代码加深印象。
- 这个简单的RPC框架包含了RPC的核心组件,即生产者、消费者、注册中心、通信协议四大部分。
整体框架
通信协议
/**
 * Request message sent from a consumer to a provider.
 *
 * <p>Carries everything the provider needs to locate and invoke a method:
 * the target class name, the method name, the parameter types and the
 * argument values, plus an ID that is echoed back in the response.
 */
public class RpcRequest implements Serializable {
    private static final long serialVersionUID = 5623523207288226903L;

    /** Request ID. */
    private String requestId;

    /** Name of the class (service interface) to call. */
    private String className;

    /** Name of the method to call. */
    private String methodName;

    /** Parameter types of the method to call. */
    private Class<?>[] parameterTypes;

    /** Argument values for the call. */
    private Object[] parameters;

    /** @return the request ID */
    public String getRequestId() {
        return requestId;
    }

    /** @param requestId the request ID */
    public void setRequestId(String requestId) {
        this.requestId = requestId;
    }

    /** @return the name of the class to call */
    public String getClassName() {
        return className;
    }

    /** @param className the name of the class to call */
    public void setClassName(String className) {
        this.className = className;
    }

    /** @return the name of the method to call */
    public String getMethodName() {
        return methodName;
    }

    /** @param methodName the name of the method to call */
    public void setMethodName(String methodName) {
        this.methodName = methodName;
    }

    /** @return the parameter types of the method to call */
    public Class<?>[] getParameterTypes() {
        return parameterTypes;
    }

    /** @param parameterTypes the parameter types of the method to call */
    public void setParameterTypes(Class<?>[] parameterTypes) {
        this.parameterTypes = parameterTypes;
    }

    /** @return the argument values for the call */
    public Object[] getParameters() {
        return parameters;
    }

    /** @param parameters the argument values for the call */
    public void setParameters(Object[] parameters) {
        this.parameters = parameters;
    }
}
- RpcRequest代表请求报文,核心字段如上注释所示。
/**
 * Response message sent from a provider back to a consumer.
 *
 * <p>Holds either the invocation result or the error raised while handling
 * the request, together with the ID of the request it answers.
 */
public class RpcResponse implements Serializable {
    private static final long serialVersionUID = -2733031642960312811L;

    /**
     * Response ID. The misspelled name ("reponse") is kept because callers
     * use {@code setReponseId}.
     */
    private String reponseId;

    /** Error raised while handling the request, or {@code null} on success. */
    private Throwable error;

    /** Invocation result. */
    private Object result;

    /** @return the response ID */
    public String getReponseId() {
        return reponseId;
    }

    /** @param reponseId the response ID */
    public void setReponseId(String reponseId) {
        this.reponseId = reponseId;
    }

    /** @return the error raised while handling the request, or {@code null} */
    public Throwable getError() {
        return error;
    }

    /** @param error the error raised while handling the request */
    public void setError(Throwable error) {
        this.error = error;
    }

    /** @return {@code true} when an error has been set on this response */
    public boolean hasError() {
        return error != null;
    }

    /** @return the invocation result */
    public Object getResult() {
        return result;
    }

    /** @param result the invocation result */
    public void setResult(Object result) {
        this.result = result;
    }
}
- RpcResponse代表响应报文,核心字段如上注释所示。
编解码中心
/**
 * Outbound encoder that serializes a message and writes it as a
 * length-prefixed frame: a 4-byte int length followed by the payload bytes.
 */
public class RpcEncoder extends MessageToByteEncoder {
    /** Only instances of this class are encoded. */
    private Class<?> genericClass;

    /**
     * @param genericClass the message type this encoder handles
     */
    public RpcEncoder(Class<?> genericClass) {
        this.genericClass = genericClass;
    }

    /**
     * Serializes {@code o} via {@code SerializationUtil} and writes length + payload.
     * Messages that are not instances of the configured class write nothing.
     */
    @Override
    protected void encode(ChannelHandlerContext channelHandlerContext, Object o, ByteBuf byteBuf) throws Exception {
        if (!genericClass.isInstance(o)) {
            return;
        }
        final byte[] payload = SerializationUtil.serialize(o);
        byteBuf.writeInt(payload.length).writeBytes(payload);
    }
}
- RpcEncoder负责实现MessageToByte的转换,包括RpcRequest/RpcResponse到对应的byte的转换。
/**
 * Inbound decoder for length-prefixed frames: reads a 4-byte int length,
 * then that many payload bytes, and deserializes them into the configured class.
 */
public class RpcDecoder extends ByteToMessageDecoder {
    /** Target type of the deserialized message. */
    private Class<?> genericClass;

    /**
     * @param genericClass the type each frame is deserialized into
     */
    public RpcDecoder(Class<?> genericClass) {
        this.genericClass = genericClass;
    }

    /**
     * Decodes at most one frame per call. Returns without consuming anything
     * when the length header or the full payload has not arrived yet.
     */
    @Override
    protected void decode(ChannelHandlerContext channelHandlerContext, ByteBuf byteBuf, List<Object> list) throws Exception {
        final int headerSize = 4;
        if (byteBuf.readableBytes() < headerSize) {
            return;
        }

        // Remember the position so the header can be re-read on a partial frame.
        byteBuf.markReaderIndex();
        final int frameLength = byteBuf.readInt();

        // A negative length cannot be a valid frame: close the channel.
        if (frameLength < 0) {
            channelHandlerContext.close();
            return;
        }

        if (byteBuf.readableBytes() >= frameLength) {
            final byte[] payload = new byte[frameLength];
            byteBuf.readBytes(payload);
            list.add(SerializationUtil.deSerialize(payload, genericClass));
        } else {
            // Payload incomplete: rewind to before the header and wait for more bytes.
            byteBuf.resetReaderIndex();
        }
    }
}
- RpcDecoder负责实现ByteToMessage的转换,包括byte到RpcRequest/RpcResponse的转换。
/**
 * Static helpers that serialize and deserialize objects with Protostuff,
 * caching one runtime schema per class.
 */
public class SerializationUtil {
    /** Schemas already created, keyed by class. */
    private static Map<Class<?>, Schema<?>> cachedSchema = new ConcurrentHashMap<>();
    /** Instantiator used to create the target object during deserialization. */
    private static Objenesis objenesis = new ObjenesisStd(true);

    private SerializationUtil() {
    }

    /**
     * Returns the cached schema for {@code cls}, creating and caching it on first use.
     */
    private static <T> Schema<T> getSchema(Class<T> cls) {
        Schema<T> cached = (Schema<T>) cachedSchema.get(cls);
        if (cached != null) {
            return cached;
        }
        Schema<T> created = RuntimeSchema.createFrom(cls);
        if (created != null) {
            cachedSchema.put(cls, created);
        }
        return created;
    }

    /**
     * Serializes an object to bytes.
     *
     * @param obj the object to serialize
     * @param <T> the object's type
     * @return the serialized bytes
     * @throws IllegalArgumentException wrapping any exception raised while serializing
     */
    public static <T> byte[] serialize(T obj) {
        Class<T> type = (Class<T>) obj.getClass();
        LinkedBuffer scratch = LinkedBuffer.allocate(LinkedBuffer.DEFAULT_BUFFER_SIZE);
        try {
            return ProtostuffIOUtil.toByteArray(obj, getSchema(type), scratch);
        } catch (Exception e) {
            throw new IllegalArgumentException(e.getMessage(), e);
        } finally {
            scratch.clear();
        }
    }

    /**
     * Deserializes bytes into a new instance of {@code cls}.
     *
     * @param data the serialized bytes
     * @param cls  the target class
     * @param <T>  the target type
     * @return the populated instance
     * @throws IllegalArgumentException wrapping any exception raised while deserializing
     */
    public static <T> T deSerialize(byte[] data, Class<T> cls) {
        try {
            T instance = objenesis.newInstance(cls);
            ProtostuffIOUtil.mergeFrom(data, instance, getSchema(cls));
            return instance;
        } catch (Exception e) {
            throw new IllegalArgumentException(e.getMessage(), e);
        }
    }
}
注册中心
/**
 * Registers a provider address in ZooKeeper.
 *
 * <p>{@link #register(String)} connects to the registry, makes sure the root
 * node exists, and creates an ephemeral sequential node holding the data.
 */
public class ServiceRegistry {
    private static final Logger LOGGER = LoggerFactory.getLogger(ServiceRegistry.class);
    /** ZooKeeper connect string. */
    private String registryAddress;
    /** Released once the ZooKeeper session reports SyncConnected. */
    private CountDownLatch latch = new CountDownLatch(1);

    /**
     * @param registryAddress ZooKeeper connect string
     */
    public ServiceRegistry(String registryAddress) {
        this.registryAddress = registryAddress;
    }

    /**
     * Registers {@code data} under the registry root. Does nothing when
     * {@code data} is {@code null} or no ZooKeeper handle could be created.
     *
     * @param data the value to store in the node (the provider address)
     */
    public void register(String data) {
        if (null == data) {
            return;
        }
        ZooKeeper zk = connectServer();
        if (null != zk) {
            createNode(zk, data);
        }
    }

    /**
     * Opens a ZooKeeper session and blocks until it reports SyncConnected.
     *
     * @return the ZooKeeper handle, or {@code null} if it could not be created
     */
    private ZooKeeper connectServer() {
        ZooKeeper zk = null;
        try {
            zk = new ZooKeeper(registryAddress, Constant.ZK_SESSION_TIMEOUT, new Watcher() {
                @Override
                public void process(WatchedEvent watchedEvent) {
                    if (watchedEvent.getState() == Event.KeeperState.SyncConnected) {
                        latch.countDown();
                    }
                }
            });
            latch.await();
        } catch (InterruptedException e) {
            // Restore the flag so callers further up can see the interruption.
            Thread.currentThread().interrupt();
            LOGGER.error("connectServer 连接zk异常 ", e);
        } catch (Exception e) {
            LOGGER.error("connectServer 连接zk异常 ", e);
        }
        return zk;
    }

    /**
     * Creates the ephemeral sequential data node, creating the root node first if needed.
     */
    private void createNode(ZooKeeper zooKeeper, String data) {
        try {
            // Explicit charset so the bytes do not depend on the platform default.
            byte[] dataBytes = data.getBytes(java.nio.charset.StandardCharsets.UTF_8);
            ensureRootExists(zooKeeper);
            String path = zooKeeper.create(Constant.ZK_DATA_PATH, dataBytes, ZooDefs.Ids.OPEN_ACL_UNSAFE, CreateMode.EPHEMERAL_SEQUENTIAL);
            LOGGER.info("createNode 创建节点 {} 成功", path);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            LOGGER.error("createNode 创建节点异常 ", e);
        } catch (Exception e) {
            LOGGER.error("createNode 创建节点异常 ", e);
        }
    }

    /**
     * Creates the persistent root node when it is missing. If the create fails
     * but the node exists afterwards (for example another provider created it
     * between the check and the create), the failure is ignored so that
     * registration can proceed.
     */
    private void ensureRootExists(ZooKeeper zooKeeper) throws Exception {
        if (null != zooKeeper.exists(Constant.ZK_REGISTRY_ROOT_PATH, false)) {
            return;
        }
        try {
            zooKeeper.create(Constant.ZK_REGISTRY_ROOT_PATH, new byte[0], ZooDefs.Ids.OPEN_ACL_UNSAFE, CreateMode.PERSISTENT);
        } catch (InterruptedException e) {
            throw e;
        } catch (Exception e) {
            if (null == zooKeeper.exists(Constant.ZK_REGISTRY_ROOT_PATH, false)) {
                throw e;
            }
        }
    }
}
- 注册中心负责把provider的ip:port注册到对应的zk节点,这里的zk节点的命名方式比较随意,没有参考意义。
/**
 * Discovers provider addresses from ZooKeeper.
 *
 * <p>On construction it connects to the registry and loads the data of every
 * child node under the registry root; the list is reloaded whenever the
 * children-changed watcher fires. {@link #discovery()} picks one entry at random.
 */
public class ServiceDiscovery {
    private static final Logger LOGGER = LoggerFactory.getLogger(ServiceDiscovery.class);
    /** Released once the ZooKeeper session reports SyncConnected. */
    private CountDownLatch latch = new CountDownLatch(1);
    /** Current snapshot of provider data; replaced wholesale on every refresh. */
    private volatile List<String> dataList = new ArrayList<>();
    /**
     * Registry address (ZooKeeper connect string).
     */
    private String registryAddress;

    /**
     * Connects to the registry and performs the initial load of provider data.
     *
     * @param registryAddress ZooKeeper connect string
     */
    public ServiceDiscovery(String registryAddress) {
        this.registryAddress = registryAddress;
        ZooKeeper zk = connectServer();
        if (null != zk) {
            watchNode(zk);
        }
    }

    /**
     * Picks one provider entry at random.
     *
     * @return a provider entry, or {@code null} when none is known
     */
    public String discovery() {
        // Read the volatile field once: a refresh may swap in a shorter list
        // between size() and get(), which would make the index invalid.
        List<String> snapshot = dataList;
        int size = snapshot.size();
        if (size == 0) {
            return null;
        }
        return snapshot.get(ThreadLocalRandom.current().nextInt(size));
    }

    /**
     * Opens a ZooKeeper session and blocks until it reports SyncConnected.
     *
     * @return the ZooKeeper handle, or {@code null} if it could not be created
     */
    private ZooKeeper connectServer() {
        ZooKeeper zk = null;
        try {
            zk = new ZooKeeper(this.registryAddress, Constant.ZK_SESSION_TIMEOUT, new Watcher() {
                @Override
                public void process(WatchedEvent watchedEvent) {
                    if (watchedEvent.getState() == Event.KeeperState.SyncConnected) {
                        latch.countDown();
                    }
                }
            });
            latch.await();
        } catch (InterruptedException e) {
            // Restore the flag so callers further up can see the interruption.
            Thread.currentThread().interrupt();
            LOGGER.error("connectServer 连接zk服务报错", e);
        } catch (Exception e) {
            LOGGER.error("connectServer 连接zk服务报错", e);
        }
        return zk;
    }

    /**
     * Loads the data of all children of the registry root into a fresh list
     * and publishes it; the registered watcher calls this method again on
     * NodeChildrenChanged.
     */
    private void watchNode(final ZooKeeper zk) {
        try {
            List<String> nodeList = zk.getChildren(Constant.ZK_REGISTRY_ROOT_PATH, new Watcher() {
                @Override
                public void process(WatchedEvent watchedEvent) {
                    if (watchedEvent.getType() == Event.EventType.NodeChildrenChanged) {
                        watchNode(zk);
                    }
                }
            });
            List<String> refreshed = new ArrayList<>(nodeList.size());
            for (String node : nodeList) {
                byte[] bytes = zk.getData(Constant.ZK_REGISTRY_ROOT_PATH + "/" + node, false, null);
                // Explicit charset so decoding does not depend on the platform default.
                refreshed.add(new String(bytes, java.nio.charset.StandardCharsets.UTF_8));
            }
            this.dataList = refreshed;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            LOGGER.error("watchNode 监视zk节点异常", e);
        } catch (Exception e) {
            LOGGER.error("watchNode 监视zk节点异常", e);
        }
    }
}
- ServiceDiscovery在启动过程中通过watchNode获取对应的provider的地址。
- ServiceDiscovery的dataList保存的是服务提供者ip:port信息。
- ServiceDiscovery通过discovery对外提供provider信息的查询。
生产者
/**
 * RPC provider bootstrap.
 *
 * <p>{@link #setApplicationContext} collects every bean annotated with
 * {@code RpcService} into {@code handlerMap}; {@link #afterPropertiesSet}
 * starts a Netty server on {@code serverAddress}, registers that address
 * with the service registry, and then blocks until the server channel closes.
 */
public class RpcServer implements ApplicationContextAware, InitializingBean {
    private static final Logger LOGGER = LoggerFactory.getLogger(RpcServer.class);
    // Server address in "host:port" form
    private String serverAddress;
    // Service registry (may be null, in which case nothing is registered)
    private ServiceRegistry serviceRegistry;
    // Mapping from interface name to service bean
    private Map<String, Object> handlerMap = new HashMap<>();

    /**
     * @param serverAddress   address to bind, in "host:port" form
     * @param serviceRegistry registry to publish the address to; may be {@code null}
     */
    public RpcServer(String serverAddress, ServiceRegistry serviceRegistry) {
        this.serverAddress = serverAddress;
        this.serviceRegistry = serviceRegistry;
    }

    /**
     * Starts the Netty server, registers the address and blocks until the
     * server channel is closed. Failures are logged rather than dropped; both
     * event loop groups are always shut down on exit.
     */
    @Override
    public void afterPropertiesSet() throws Exception {
        EventLoopGroup masterGroup = new NioEventLoopGroup();
        EventLoopGroup workerGroup = new NioEventLoopGroup();
        try {
            ServerBootstrap bootstrap = new ServerBootstrap();
            bootstrap.group(masterGroup, workerGroup)
                    .channel(NioServerSocketChannel.class)
                    .childHandler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        protected void initChannel(SocketChannel socketChannel) throws Exception {
                            socketChannel.pipeline()
                                    .addLast(new RpcDecoder(RpcRequest.class))
                                    .addLast(new RpcEncoder(RpcResponse.class))
                                    .addLast(new RpcHandler(handlerMap));
                        }
                    })
                    .option(ChannelOption.SO_BACKLOG, 1024)
                    .childOption(ChannelOption.SO_KEEPALIVE, true);
            // Parse host and port from the configured address
            String[] array = serverAddress.split(":");
            String host = array[0];
            int port = Integer.parseInt(array[1]);
            // Start the RPC server
            ChannelFuture channelFuture = bootstrap.bind(host, port).sync();
            LOGGER.debug("server started on port: {}", port);
            if (null != serviceRegistry) {
                // Register the service address
                serviceRegistry.register(serverAddress);
                LOGGER.debug("register service:{}", serverAddress);
            }
            // Block until the server channel is closed
            channelFuture.channel().closeFuture().sync();
        } catch (InterruptedException e) {
            // Restore the flag so the caller can observe the interruption.
            Thread.currentThread().interrupt();
            LOGGER.error("RPC server interrupted, address: {}", serverAddress, e);
        } catch (Exception e) {
            // Previously swallowed: bind, parse and registration failures were invisible.
            LOGGER.error("RPC server failed, address: {}", serverAddress, e);
        } finally {
            workerGroup.shutdownGracefully();
            masterGroup.shutdownGracefully();
        }
    }

    /**
     * Collects all beans annotated with {@code RpcService}, keyed by the name
     * of the interface given in the annotation's {@code value()}.
     */
    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        Map<String, Object> serviceBeanMap = applicationContext.getBeansWithAnnotation(RpcService.class);
        if (MapUtils.isNotEmpty(serviceBeanMap)) {
            for (Object serviceBean : serviceBeanMap.values()) {
                String interfaceName = serviceBean.getClass().getAnnotation(RpcService.class).value().getName();
                handlerMap.put(interfaceName, serviceBean);
            }
        }
    }
}
- handlerMap保存的是Interface => serviceBean的映射关系。
- RpcServer通过Netty来创建Server对象实现服务监听,通过绑定RpcDecoder、RpcEncoder、RpcHandler来实现消息的编解码和实际操作。
/**
 * Server-side handler: invokes the requested service method and writes the
 * result (or the error) back as an {@code RpcResponse}, then closes the channel.
 */
public class RpcHandler extends SimpleChannelInboundHandler<RpcRequest> {
    private static final Logger LOGGER = LoggerFactory.getLogger(RpcHandler.class);
    /** Mapping from interface name to service bean. */
    private final Map<String, Object> handlerMap;

    /**
     * @param handlerMap mapping from interface name to service bean
     */
    public RpcHandler(Map<String, Object> handlerMap) {
        this.handlerMap = handlerMap;
    }

    /**
     * Handles one request. Any failure is logged and reported to the caller
     * through the response's error field; the channel is closed after the write.
     */
    @Override
    protected void channelRead0(ChannelHandlerContext channelHandlerContext, RpcRequest rpcRequest) throws Exception {
        RpcResponse rpcResponse = new RpcResponse();
        rpcResponse.setReponseId(rpcRequest.getRequestId());
        try {
            Object result = this.handle(rpcRequest);
            rpcResponse.setResult(result);
        } catch (InvocationTargetException e) {
            // Report the exception thrown by the service method itself rather than the wrapper.
            Throwable cause = e.getTargetException();
            Throwable reported = (null != cause) ? cause : e;
            LOGGER.error("rpc invocation failed, requestId: {}", rpcRequest.getRequestId(), reported);
            rpcResponse.setError(reported);
        } catch (Exception e) {
            LOGGER.error("rpc request handling failed, requestId: {}", rpcRequest.getRequestId(), e);
            rpcResponse.setError(e);
        }
        channelHandlerContext.writeAndFlush(rpcResponse).addListener(ChannelFutureListener.CLOSE);
    }

    /**
     * Looks up the service bean by the request's class name and invokes the
     * requested method through cglib FastClass.
     *
     * @throws IllegalStateException     if no service bean is registered for the class name
     * @throws InvocationTargetException if the invoked method throws
     */
    private Object handle(RpcRequest request) throws InvocationTargetException {
        String className = request.getClassName();
        Object serviceBean = this.handlerMap.get(className);
        if (null == serviceBean) {
            // Fail with a descriptive message instead of a bare NullPointerException.
            throw new IllegalStateException("no service registered for: " + className);
        }
        String methodName = request.getMethodName();
        Class<?>[] parameterTypes = request.getParameterTypes();
        Object[] parameters = request.getParameters();
        Class<?> serviceClass = serviceBean.getClass();
        FastClass serviceFastClass = FastClass.create(serviceClass);
        FastMethod serviceFastMethod = serviceFastClass.getMethod(methodName, parameterTypes);
        return serviceFastMethod.invoke(serviceBean, parameters);
    }
}
- RpcHandler#channelRead0接收的是已经由RpcDecoder解码得到的RpcRequest对象,调用handle处理后将RpcResponse写回,并在写完后关闭连接。
- RpcHandler#handle内部负责根据RpcRequest指定的类名、方法名和参数找到对应的serviceBean对象,并通过cglib的FastClass/FastMethod#invoke(基于生成的字节码直接调用,而不是JDK反射)实现服务调用。
消费者
/**
 * RPC client that opens a connection, sends one request and blocks until the
 * response arrives or the channel becomes inactive.
 */
public class RpcClient extends SimpleChannelInboundHandler<RpcResponse> {
    private static final Logger LOGGER = LoggerFactory.getLogger(RpcClient.class);
    /**
     * Host name
     */
    private String host;
    /**
     * Port
     */
    private int port;
    /**
     * RPC response; guarded by {@code obj}
     */
    private RpcResponse response;
    /**
     * Set when the channel becomes inactive; guarded by {@code obj}
     */
    private boolean channelClosed;
    private final Object obj = new Object();

    /**
     * @param host provider host
     * @param port provider port
     */
    public RpcClient(String host, int port) {
        this.host = host;
        this.port = port;
    }

    /**
     * Stores the response and wakes the thread waiting in {@link #send}.
     */
    @Override
    protected void channelRead0(ChannelHandlerContext channelHandlerContext, RpcResponse rpcResponse) throws Exception {
        synchronized (obj) {
            // Publish under the lock so the waiter's condition check cannot miss it.
            this.response = rpcResponse;
            obj.notifyAll();
        }
    }

    /**
     * Wakes the waiting thread when the channel goes inactive, so that
     * {@link #send} does not block forever if no response was received.
     */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        synchronized (obj) {
            channelClosed = true;
            obj.notifyAll();
        }
        super.channelInactive(ctx);
    }

    /**
     * Connects to the provider, sends the request and waits for the response.
     *
     * @param rpcRequest the request to send
     * @return the response, or {@code null} if the call failed or the channel
     *         closed before a response arrived
     */
    public RpcResponse send(RpcRequest rpcRequest) throws Exception {
        EventLoopGroup group = new NioEventLoopGroup();
        try {
            Bootstrap bootstrap = new Bootstrap();
            bootstrap.group(group)
                    .channel(NioSocketChannel.class)
                    .handler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        protected void initChannel(SocketChannel socketChannel) throws Exception {
                            socketChannel.pipeline()
                                    .addLast(new RpcEncoder(RpcRequest.class))
                                    .addLast(new RpcDecoder(RpcResponse.class))
                                    .addLast(RpcClient.this);
                        }
                    })
                    .option(ChannelOption.SO_KEEPALIVE, true);
            ChannelFuture future = bootstrap.connect(host, port).sync();
            future.channel().writeAndFlush(rpcRequest).sync();
            RpcResponse received;
            synchronized (obj) {
                // Guarded wait: returns at once if the response already arrived,
                // and re-checks the condition after a spurious wakeup.
                while (null == response && !channelClosed) {
                    obj.wait();
                }
                received = response;
            }
            if (null != received) {
                // Wait for the RPC connection to close
                future.channel().closeFuture().sync();
            }
            return received;
        } catch (InterruptedException e) {
            // Restore the flag so the caller can observe the interruption.
            Thread.currentThread().interrupt();
            LOGGER.error("RpcClient 异常 ", e);
        } catch (Exception e) {
            LOGGER.error("RpcClient 异常 ", e);
        } finally {
            group.shutdownGracefully();
        }
        return null;
    }
}
- channelRead0负责读取响应报文。
- send内部通过Netty创建Client功能负责发送报文即可。
/**
 * Creates consumer-side proxies (JDK dynamic proxies) that turn interface
 * method calls into RPC requests sent through {@code RpcClient}.
 */
public class RpcProxy {
    private static final Logger LOGGER = LoggerFactory.getLogger(RpcProxy.class);
    /** Fallback provider address, used only when no service discovery is set. */
    private String serverAddress;
    private ServiceDiscovery serviceDiscovery;

    /**
     * @param serviceDiscovery source of provider addresses
     */
    public RpcProxy(ServiceDiscovery serviceDiscovery) {
        this.serviceDiscovery = serviceDiscovery;
    }

    /**
     * Creates a proxy implementing {@code interfaceClass}.
     *
     * @param interfaceClass the service interface
     * @param <T>            the type the caller assigns the proxy to
     * @return a proxy whose interface methods are executed remotely
     */
    @SuppressWarnings("unchecked")
    public <T> T create(final Class<?> interfaceClass) {
        return (T) Proxy.newProxyInstance(
                interfaceClass.getClassLoader(),
                new Class<?>[]{interfaceClass},
                new InvocationHandler() {
                    @Override
                    public Object invoke(Object o, Method method, Object[] objects) throws Throwable {
                        // equals/hashCode/toString belong to the proxy itself, not to the remote service.
                        if (method.getDeclaringClass() == Object.class) {
                            return invokeObjectMethod(o, method, objects, interfaceClass);
                        }
                        RpcRequest request = new RpcRequest();
                        request.setRequestId(UUID.randomUUID().toString());
                        request.setClassName(method.getDeclaringClass().getName());
                        request.setMethodName(method.getName());
                        request.setParameterTypes(method.getParameterTypes());
                        request.setParameters(objects);
                        // Keep the address in a local variable: storing it in the shared
                        // field lets concurrent invocations overwrite each other's target.
                        String address = serverAddress;
                        if (null != serviceDiscovery) {
                            // Discover a provider
                            address = serviceDiscovery.discovery();
                        }
                        if (address == null) {
                            throw new RuntimeException("serverAddress is null...");
                        }
                        String[] array = address.split(":");
                        String host = array[0];
                        int port = Integer.parseInt(array[1]);
                        RpcClient client = new RpcClient(host, port);
                        long startTime = System.currentTimeMillis();
                        // Send the request through the RPC client and get the response
                        RpcResponse response = client.send(request);
                        LOGGER.debug("send rpc request elapsed time: {}ms...", System.currentTimeMillis() - startTime);
                        if (response == null) {
                            throw new RuntimeException("response is null...");
                        }
                        // Return the result, or rethrow the remote error
                        if (response.hasError()) {
                            throw response.getError();
                        }
                        return response.getResult();
                    }
                });
    }

    /**
     * Handles the {@code java.lang.Object} methods dispatched to the handler
     * using proxy identity, without a remote call.
     */
    private static Object invokeObjectMethod(Object proxy, Method method, Object[] args, Class<?> interfaceClass) {
        String name = method.getName();
        if ("equals".equals(name)) {
            return proxy == args[0];
        }
        if ("hashCode".equals(name)) {
            return System.identityHashCode(proxy);
        }
        if ("toString".equals(name)) {
            return interfaceClass.getName() + "@" + Integer.toHexString(System.identityHashCode(proxy));
        }
        throw new UnsupportedOperationException(method.toString());
    }
}
- RpcProxy内部通过JDK动态代理(java.lang.reflect.Proxy)来实现consumer的代理功能。
调用代码
/**
 * Provider entry point: creates the Spring application context from the
 * provider XML configuration.
 */
public class RpcProvider {
    public static void main(String[] args) {
        final String configLocation = "provider/application-provider.xml";
        new ClassPathXmlApplicationContext(configLocation);
    }
}
/**
 * Consumer entry point: loads the consumer Spring context, creates a proxy
 * for {@code HelloService}, calls it once and prints the result.
 */
public class RpcConsumer {
    public static void main(String[] args) {
        ClassPathXmlApplicationContext context =
                new ClassPathXmlApplicationContext("consumer/application-consumer.xml");
        RpcProxy rpcProxy = context.getBean(RpcProxy.class);
        HelloService helloService = rpcProxy.create(HelloService.class);
        System.out.println(helloService.sayHello("sun boy"));
    }
}
github源码
网友评论