使用 Netty、Spring、自定义注解和 JDK 动态代理实现一个 RPC 框架
架构
common 通用模块,包含编解码、请求/响应对象和序列化工具
registry 负责 ZooKeeper 节点的注册与发现
service 服务接口模块,包含实体类以及对外提供的各种方法
server 服务端框架
client 客户端框架
sample server 使用服务端框架发布服务
sample client 使用客户端框架调用(消费)服务
![](https://img.haomeiwen.com/i2670708/9374977cd6db260b.png)
Registry
1.pom
<dependency>
<groupId>org.apache.zookeeper</groupId>
<artifactId>zookeeper</artifactId>
<version>3.4.5</version>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-log4j12</artifactId>
<version>1.7.7</version>
</dependency>
2.常量
/**
 * ZooKeeper settings shared by service registration and service discovery.
 */
public class Constant {
    /** Session timeout handed to the ZooKeeper client constructor (milliseconds). */
    public static final int ZK_SESSION_TIMEOUT = 5 * 1000;

    /** Persistent parent node under which every server registers itself. */
    public static final String ZK_REGISTRY_PATH = "/registry";

    /** Path prefix for each server's own node (created as ephemeral sequential). */
    public static final String ZK_DATA_PATH = ZK_REGISTRY_PATH + "/" + "data";
}
3.注册节点
public class ServiceRegistry {
private static final Logger LOGGER = LoggerFactory
.getLogger(ServiceRegistry.class);
private CountDownLatch latch = new CountDownLatch(1);
private String registryAddress;
public ServiceRegistry(String registryAddress) {
//zookeeper的地址
this.registryAddress = registryAddress;
}
/**
* 注册服务
*
* @param data
*/
public void register(String data) {
if (data != null) {
// 连接zookeeper
ZooKeeper zk = connectServer();
if (zk != null) {
// 创建数据节点
createNode(zk, data);
}
}
}
/**
* 创建zookeeper链接,监听
*
* @return
*/
private ZooKeeper connectServer() {
ZooKeeper zk = null;
try {
zk = new ZooKeeper(registryAddress, Constant.ZK_SESSION_TIMEOUT,
new Watcher() {
public void process(WatchedEvent event) {
if (event.getState() == Event.KeeperState.SyncConnected) {
latch.countDown();
}
}
});
latch.await();
} catch (Exception e) {
LOGGER.error("", e);
}
return zk;
}
/**
* 创建节点
*
* @param zk
* @param data
*/
private void createNode(ZooKeeper zk, String data) {
try {
byte[] bytes = data.getBytes();
if (zk.exists(Constant.ZK_REGISTRY_PATH, null) == null) {
zk.create(Constant.ZK_REGISTRY_PATH, null, Ids.OPEN_ACL_UNSAFE,
CreateMode.PERSISTENT);
}
String path = zk.create(Constant.ZK_DATA_PATH, bytes,
Ids.OPEN_ACL_UNSAFE, CreateMode.EPHEMERAL_SEQUENTIAL);
LOGGER.debug("create zookeeper node ({} => {})", path, data);
} catch (Exception e) {
LOGGER.error("", e);
}
}
4.发现节点
/**
 * Client-side service discovery: watches the ZooKeeper registry for server nodes and
 * hands out a randomly chosen server's data on each call (simple load balancing).
 */
public class ServiceDiscovery {
    private static final Logger LOGGER = LoggerFactory
            .getLogger(ServiceDiscovery.class);
    private CountDownLatch latch = new CountDownLatch(1);
    // Replaced wholesale by watchNode(); readers must take one snapshot of the reference.
    private volatile List<String> dataList = new ArrayList<String>();
    private String registryAddress;

    /**
     * Connects to ZooKeeper and loads the current server list, installing a watch that
     * reloads it whenever the registry's children change.
     *
     * @param registryAddress ZooKeeper connect string passed to the ZooKeeper client
     */
    public ServiceDiscovery(String registryAddress) {
        this.registryAddress = registryAddress;
        ZooKeeper zk = connectServer();
        if (zk != null) {
            watchNode(zk);
        }
    }

    /**
     * Picks one registered server.
     *
     * @return the data of a random registered node (the server address written by
     *         ServiceRegistry), or null when no server is currently known
     */
    public String discover() {
        // Read the volatile field exactly once: watchNode() may swap in a shorter list
        // between a size() on one list and a get() on another, going out of bounds.
        List<String> snapshot = dataList;
        int size = snapshot.size();
        String data = null;
        if (size == 1) {
            data = snapshot.get(0);
            LOGGER.debug("using only data: {}", data);
        } else if (size > 1) {
            data = snapshot.get(ThreadLocalRandom.current().nextInt(size));
            LOGGER.debug("using random data: {}", data);
        }
        return data;
    }

    /**
     * Opens a ZooKeeper session and blocks until it reports SyncConnected.
     *
     * @return the connected client, or null if connecting failed or was interrupted
     */
    private ZooKeeper connectServer() {
        ZooKeeper zk = null;
        try {
            zk = new ZooKeeper(registryAddress, Constant.ZK_SESSION_TIMEOUT,
                    new Watcher() {
                        public void process(WatchedEvent event) {
                            if (event.getState() == Event.KeeperState.SyncConnected) {
                                latch.countDown();
                            }
                        }
                    });
            latch.await();
            return zk;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            LOGGER.error("", e);
        } catch (Exception e) {
            LOGGER.error("", e);
        }
        // Do not use (or leak) a session that never finished connecting.
        closeQuietly(zk);
        return null;
    }

    /**
     * Reloads the server list from the registry's children and re-arms the children watch.
     *
     * @param zk connected client
     */
    private void watchNode(final ZooKeeper zk) {
        try {
            // ZooKeeper watches fire once; the callback calls watchNode again, which both
            // reloads the list and registers a fresh watch.
            List<String> nodeList = zk.getChildren(Constant.ZK_REGISTRY_PATH,
                    new Watcher() {
                        public void process(WatchedEvent event) {
                            if (event.getType() == Event.EventType.NodeChildrenChanged) {
                                watchNode(zk);
                            }
                        }
                    });
            List<String> dataList = new ArrayList<String>(nodeList.size());
            for (String node : nodeList) {
                String data = readNodeData(zk, Constant.ZK_REGISTRY_PATH + "/" + node);
                if (data != null) {
                    dataList.add(data);
                }
            }
            LOGGER.debug("node data: {}", dataList);
            // Publish the new list with a single volatile write.
            this.dataList = dataList;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            LOGGER.error("", e);
        } catch (Exception e) {
            LOGGER.error("", e);
        }
    }

    /**
     * Reads one registry child as a UTF-8 string.
     *
     * @return the node's data, or null if it could not be read (e.g. the node was
     *         removed after getChildren returned) or holds no data
     */
    private String readNodeData(ZooKeeper zk, String path) throws InterruptedException {
        try {
            byte[] bytes = zk.getData(path, false, null);
            return bytes == null ? null : new String(bytes, java.nio.charset.StandardCharsets.UTF_8);
        } catch (InterruptedException e) {
            throw e;
        } catch (Exception e) {
            // Skip just this child instead of abandoning the whole refresh.
            LOGGER.warn("skipping registry node {}", path, e);
            return null;
        }
    }

    /** Closes a session that could not be used; an interrupt during close is re-asserted. */
    private static void closeQuietly(ZooKeeper zk) {
        if (zk == null) {
            return;
        }
        try {
            zk.close();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            LOGGER.warn("interrupted while closing zookeeper session", e);
        }
    }
}
common
1.pom
<dependencies>
<!-- SLF4J -->
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-log4j12</artifactId>
<version>1.7.7</version>
</dependency>
<!-- Netty -->
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-all</artifactId>
<version>4.0.24.Final</version>
</dependency>
<!-- Protostuff -->
<dependency>
<groupId>com.dyuproject.protostuff</groupId>
<artifactId>protostuff-core</artifactId>
<version>1.0.8</version>
</dependency>
<dependency>
<groupId>com.dyuproject.protostuff</groupId>
<artifactId>protostuff-runtime</artifactId>
<version>1.0.8</version>
</dependency>
<!-- Objenesis -->
<dependency>
<groupId>org.objenesis</groupId>
<artifactId>objenesis</artifactId>
<version>2.1</version>
</dependency>
</dependencies>
2.编码
/**
 * RPC encoder: writes each outbound message of the configured class as a 4-byte length
 * prefix ({@code writeInt}) followed by its {@link SerializationUtil}-serialized bytes.
 * This is the framing that RpcDecoder reads.
 */
public class RpcEncoder extends MessageToByteEncoder<Object> {
    private final Class<?> genericClass;

    /**
     * @param genericClass only messages that are instances of this class are serialized
     */
    public RpcEncoder(Class<?> genericClass) {
        this.genericClass = genericClass;
    }

    /**
     * Serializes {@code in} into {@code out} when it is an instance of the configured class.
     * Messages of any other type are not serialized; nothing is written for them.
     */
    @Override
    public void encode(ChannelHandlerContext ctx, Object in, ByteBuf out)
            throws Exception {
        if (genericClass.isInstance(in)) {
            byte[] data = SerializationUtil.serialize(in);
            out.writeInt(data.length);
            out.writeBytes(data);
        }
    }
}
3.解码
/**
 * RPC decoder: reads frames of a 4-byte length prefix followed by that many bytes, and
 * deserializes each frame into an instance of the configured class.
 */
public class RpcDecoder extends ByteToMessageDecoder {
    private final Class<?> genericClass;

    /**
     * @param genericClass class every frame is deserialized into
     */
    public RpcDecoder(Class<?> genericClass) {
        this.genericClass = genericClass;
    }

    /**
     * Decodes at most one complete frame from {@code in}. Returns without consuming input
     * when the frame has not fully arrived yet; Netty calls again when more bytes come in.
     */
    @Override
    public final void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
        // Wait until at least the length prefix is available.
        if (in.readableBytes() < 4) {
            return;
        }
        in.markReaderIndex();
        int dataLength = in.readInt();
        if (dataLength < 0) {
            // Corrupt frame: close the connection and discard what is buffered so the
            // decode loop does not try to interpret the remaining bytes.
            in.skipBytes(in.readableBytes());
            ctx.close();
            return;
        }
        if (in.readableBytes() < dataLength) {
            // Partial frame: rewind to the length prefix and wait for more bytes.
            in.resetReaderIndex();
            return;
        }
        byte[] data = new byte[dataLength];
        in.readBytes(data);
        Object obj = SerializationUtil.deserialize(data, genericClass);
        out.add(obj);
    }
}
4.请求
/**
 * A remote method call: built by the client proxy and executed by the server handler.
 */
public class RpcRequest {
    // Correlates a request with its RpcResponse.
    private String requestId;
    // Fully-qualified name of the service interface declaring the method.
    private String className;
    private String methodName;
    private Class<?>[] parameterTypes;
    private Object[] parameters;

    public String getRequestId() {
        return requestId;
    }

    public void setRequestId(String requestId) {
        this.requestId = requestId;
    }

    public String getClassName() {
        return className;
    }

    public void setClassName(String className) {
        this.className = className;
    }

    public String getMethodName() {
        return methodName;
    }

    public void setMethodName(String methodName) {
        this.methodName = methodName;
    }

    public Class<?>[] getParameterTypes() {
        return parameterTypes;
    }

    public void setParameterTypes(Class<?>[] parameterTypes) {
        this.parameterTypes = parameterTypes;
    }

    public Object[] getParameters() {
        return parameters;
    }

    public void setParameters(Object[] parameters) {
        this.parameters = parameters;
    }
}
5.响应
/**
 * The server's reply to an {@link RpcRequest}: {@code error} is set when the call failed,
 * otherwise {@code result} holds the method's return value.
 */
public class RpcResponse {
    // Copied from the RpcRequest this response answers.
    private String requestId;
    private Throwable error;
    private Object result;

    /** @return true when the remote call failed and {@link #getError()} is non-null */
    public boolean isError() {
        return error != null;
    }

    public String getRequestId() {
        return requestId;
    }

    public void setRequestId(String requestId) {
        this.requestId = requestId;
    }

    public Throwable getError() {
        return error;
    }

    public void setError(Throwable error) {
        this.error = error;
    }

    public Object getResult() {
        return result;
    }

    public void setResult(Object result) {
        this.result = result;
    }
}
6.序列化
private static Map<Class<?>, Schema<?>> cachedSchema = new ConcurrentHashMap<Class<?>, Schema<?>>();
private static Objenesis objenesis = new ObjenesisStd(true);
private SerializationUtil() {
}
/**
* 获取类的schema
* @param cls
* @return
*/
@SuppressWarnings("unchecked")
private static <T> Schema<T> getSchema(Class<T> cls) {
Schema<T> schema = (Schema<T>) cachedSchema.get(cls);
if (schema == null) {
schema = RuntimeSchema.createFrom(cls);
if (schema != null) {
cachedSchema.put(cls, schema);
}
}
return schema;
}
/**
* 序列化(对象 -> 字节数组)
*/
@SuppressWarnings("unchecked")
public static <T> byte[] serialize(T obj) {
Class<T> cls = (Class<T>) obj.getClass();
LinkedBuffer buffer = LinkedBuffer.allocate(LinkedBuffer.DEFAULT_BUFFER_SIZE);
try {
Schema<T> schema = getSchema(cls);
return ProtostuffIOUtil.toByteArray(obj, schema, buffer);//序列化
} catch (Exception e) {
throw new IllegalStateException(e.getMessage(), e);
} finally {
buffer.clear();
}
}
/**
* 反序列化(字节数组 -> 对象)
*/
public static <T> T deserialize(byte[] data, Class<T> cls) {
try {
/*
* 如果一个类没有参数为空的构造方法时候,那么你直接调用newInstance方法试图得到一个实例对象的时候是会抛出异常的
* 通过ObjenesisStd可以完美的避开这个问题
* */
T message = (T) objenesis.newInstance(cls);//实例化
Schema<T> schema = getSchema(cls);//获取类的schema
ProtostuffIOUtil.mergeFrom(data, message, schema);
return message;
} catch (Exception e) {
throw new IllegalStateException(e.getMessage(), e);
}
}
server
1.pom
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-log4j12</artifactId>
<version>1.7.7</version>
</dependency>
<!-- Spring -->
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-context</artifactId>
<version>3.2.12.RELEASE</version>
</dependency>
<!-- Apache Commons Collections -->
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-collections4</artifactId>
<version>4.0</version>
</dependency>
<!-- CGLib -->
<dependency>
<groupId>cglib</groupId>
<artifactId>cglib</artifactId>
<version>3.1</version>
</dependency>
<!-- RPC Common -->
<dependency>
<groupId>org.apache.rpc</groupId>
<artifactId>common</artifactId>
<version>0.0.1-SNAPSHOT</version>
</dependency>
<!-- RPC Registry -->
<dependency>
<groupId>org.apache.rpc</groupId>
<artifactId>registry</artifactId>
<version>0.0.1-SNAPSHOT</version>
</dependency>
2.RPC处理器
/**
 * Server-side handler: executes each incoming {@link RpcRequest} against the registered
 * service bean and writes back an {@link RpcResponse}, then closes the connection.
 */
public class RpcHandler extends SimpleChannelInboundHandler<RpcRequest> {
    private static final Logger LOGGER = LoggerFactory
            .getLogger(RpcHandler.class);
    // Service interface name -> implementing bean.
    private final Map<String, Object> handlerMap;

    /**
     * @param handlerMap service beans keyed by the fully-qualified interface name
     */
    public RpcHandler(Map<String, Object> handlerMap) {
        this.handlerMap = handlerMap;
    }

    /**
     * Invokes the requested method and replies with either its result or the thrown error.
     */
    @Override
    public void channelRead0(final ChannelHandlerContext ctx, RpcRequest request)
            throws Exception {
        RpcResponse response = new RpcResponse();
        response.setRequestId(request.getRequestId());
        try {
            Object result = handle(request);
            response.setResult(result);
        } catch (Throwable t) {
            response.setError(t);
        }
        // One request per connection: close once the response has been flushed.
        ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
    }

    /**
     * Looks up the bean for the request's interface and reflectively invokes the method.
     *
     * @return the method's return value
     * @throws Throwable the service method's own exception, or a lookup/reflection error
     */
    private Object handle(RpcRequest request) throws Throwable {
        String className = request.getClassName();
        Object serviceBean = handlerMap.get(className);
        if (serviceBean == null) {
            // Reject unknown services before Class.forName so a client cannot trigger
            // loading of arbitrary classes, and report a clear error instead of an NPE.
            throw new IllegalArgumentException("no RPC service registered for " + className);
        }
        String methodName = request.getMethodName();
        Class<?>[] parameterTypes = request.getParameterTypes();
        Object[] parameters = request.getParameters();
        // Resolve the method on the registered interface, so only interface methods are callable.
        Class<?> serviceInterface = Class.forName(className);
        Method method = serviceInterface.getMethod(methodName, parameterTypes);
        try {
            return method.invoke(serviceBean, parameters);
        } catch (java.lang.reflect.InvocationTargetException e) {
            // Send back the service's own exception, not the reflection wrapper.
            throw e.getTargetException();
        }
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
        LOGGER.error("server caught exception", cause);
        ctx.close();
    }
}
3.RPC服务
/**
 * Netty-based RPC server configured as a Spring bean.
 *
 * <p>{@link #setApplicationContext} collects every bean annotated with {@link RpcService},
 * keyed by its service interface name. {@link #afterPropertiesSet} binds
 * {@code serverAddress}, registers it with the optional {@link ServiceRegistry}, and then
 * blocks the calling thread until the server channel is closed.
 */
public class RpcServer implements ApplicationContextAware, InitializingBean {
    private static final Logger LOGGER = LoggerFactory
            .getLogger(RpcServer.class);
    private String serverAddress;
    private ServiceRegistry serviceRegistry;
    // Service interface name -> bean; filled before the server starts, then only read.
    private Map<String, Object> handlerMap = new HashMap<String, Object>();

    /**
     * @param serverAddress address to bind, as "host:port"
     */
    public RpcServer(String serverAddress) {
        this.serverAddress = serverAddress;
    }

    /**
     * @param serverAddress   address to bind, as "host:port"; also the value registered
     * @param serviceRegistry registry the address is published to once bound
     */
    public RpcServer(String serverAddress, ServiceRegistry serviceRegistry) {
        this.serverAddress = serverAddress;
        this.serviceRegistry = serviceRegistry;
    }

    /**
     * Registers every {@code @RpcService} bean under its declared interface name.
     */
    public void setApplicationContext(ApplicationContext ctx)
            throws BeansException {
        Map<String, Object> serviceBeanMap = ctx
                .getBeansWithAnnotation(RpcService.class);
        if (MapUtils.isNotEmpty(serviceBeanMap)) {
            for (Map.Entry<String, Object> entry : serviceBeanMap.entrySet()) {
                // Ask Spring for the annotation instead of bean.getClass().getAnnotation():
                // if the bean is a proxy (JDK or CGLIB), its runtime class does not carry
                // the non-inherited @RpcService and getAnnotation() would return null.
                RpcService rpcService = ctx.findAnnotationOnBean(entry.getKey(), RpcService.class);
                String interfaceName = rpcService.value().getName();
                handlerMap.put(interfaceName, entry.getValue());
            }
        }
    }

    /**
     * Starts the server. Does not return until the server channel closes.
     *
     * @throws IllegalArgumentException if {@code serverAddress} is not "host:port"
     */
    public void afterPropertiesSet() throws Exception {
        // Parse first, so a malformed address fails before any event loops are created.
        // Split on the last ':' so the host part may itself contain colons.
        int colon = serverAddress.lastIndexOf(':');
        if (colon <= 0 || colon == serverAddress.length() - 1) {
            throw new IllegalArgumentException("serverAddress must be host:port, got: " + serverAddress);
        }
        String host = serverAddress.substring(0, colon);
        int port = Integer.parseInt(serverAddress.substring(colon + 1));

        EventLoopGroup bossGroup = new NioEventLoopGroup();
        EventLoopGroup workerGroup = new NioEventLoopGroup();
        try {
            ServerBootstrap bootstrap = new ServerBootstrap();
            bootstrap
                    .group(bossGroup, workerGroup)
                    .channel(NioServerSocketChannel.class)
                    .childHandler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        public void initChannel(SocketChannel channel)
                                throws Exception {
                            channel.pipeline()
                                    .addLast(new RpcDecoder(RpcRequest.class))   // inbound: bytes -> RpcRequest
                                    .addLast(new RpcEncoder(RpcResponse.class))  // outbound: RpcResponse -> bytes
                                    .addLast(new RpcHandler(handlerMap));        // executes the request
                        }
                    }).option(ChannelOption.SO_BACKLOG, 128)
                    .childOption(ChannelOption.SO_KEEPALIVE, true);
            ChannelFuture future = bootstrap.bind(host, port).sync();
            LOGGER.debug("server started on port {}", port);
            // Publish the address only after the bind has succeeded.
            if (serviceRegistry != null) {
                serviceRegistry.register(serverAddress);
            }
            future.channel().closeFuture().sync();
        } finally {
            workerGroup.shutdownGracefully();
            bossGroup.shutdownGracefully();
        }
    }
}
4.自定义RPC注解
/**
 * Marks a Spring bean as an RPC service implementation.
 *
 * <p>Place it on the implementing class, not on the interface: RpcServer finds annotated
 * beans and registers each under the fully-qualified name of {@link #value()}. Because it
 * is meta-annotated with {@link Component}, annotated classes are also Spring components.
 */
@Component
@Retention(RetentionPolicy.RUNTIME) // must be visible to reflection at runtime
@Target(ElementType.TYPE)           // classes (and other types) only
public @interface RpcService {
    /** The service interface this bean exposes over RPC. */
    Class<?> value();
}
Client
1.pom
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-log4j12</artifactId>
<version>1.7.7</version>
</dependency>
<!-- Netty -->
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-all</artifactId>
<version>4.0.24.Final</version>
</dependency>
<!-- RPC Common -->
<dependency>
<groupId>org.apache.rpc</groupId>
<artifactId>common</artifactId>
<version>0.0.1-SNAPSHOT</version>
</dependency>
<!-- RPC Registry -->
<dependency>
<groupId>org.apache.rpc</groupId>
<artifactId>registry</artifactId>
<version>0.0.1-SNAPSHOT</version>
</dependency>
2.RPC代理
/**
 * Client entry point: creates JDK dynamic proxies whose method calls are sent to an RPC
 * server. The server is either the fixed {@code serverAddress} or, when a
 * {@link ServiceDiscovery} is supplied, an address it picks on every call.
 */
public class RpcProxy {
    private String serverAddress;
    private ServiceDiscovery serviceDiscovery;

    /**
     * @param serverAddress fixed server address, as "host:port"
     */
    public RpcProxy(String serverAddress) {
        this.serverAddress = serverAddress;
    }

    /**
     * @param serviceDiscovery used to choose a server address for each call
     */
    public RpcProxy(ServiceDiscovery serviceDiscovery) {
        this.serviceDiscovery = serviceDiscovery;
    }

    /**
     * Creates a proxy implementing {@code interfaceClass}. Each interface method call opens
     * a connection, sends an RpcRequest and returns the remote result (or throws the
     * remote error). {@code equals}, {@code hashCode} and {@code toString} are answered
     * locally.
     *
     * @param interfaceClass service interface to implement
     * @param <T>            type the caller assigns the proxy to (should be interfaceClass)
     * @return the proxy
     */
    @SuppressWarnings("unchecked")
    public <T> T create(final Class<?> interfaceClass) {
        return (T) Proxy.newProxyInstance(interfaceClass.getClassLoader(),
                new Class<?>[] { interfaceClass }, new InvocationHandler() {
                    public Object invoke(Object proxy, Method method,
                            Object[] args) throws Throwable {
                        // Object methods (e.g. toString() from a logger or debugger) must
                        // not become network calls for "java.lang.Object" on the server.
                        if (method.getDeclaringClass() == Object.class) {
                            return invokeObjectMethod(proxy, method, args, interfaceClass);
                        }
                        RpcRequest request = new RpcRequest();
                        request.setRequestId(UUID.randomUUID().toString());
                        request.setClassName(method.getDeclaringClass()
                                .getName());
                        request.setMethodName(method.getName());
                        request.setParameterTypes(method.getParameterTypes());
                        request.setParameters(args);
                        // Resolve the target into a local variable: proxies may be called
                        // concurrently, so the shared field must not be overwritten.
                        String address = serviceDiscovery != null
                                ? serviceDiscovery.discover() : serverAddress;
                        if (address == null) {
                            throw new IllegalStateException("no RPC server available for "
                                    + interfaceClass.getName());
                        }
                        int colon = address.lastIndexOf(':');
                        if (colon <= 0) {
                            throw new IllegalStateException("invalid RPC server address: " + address);
                        }
                        String host = address.substring(0, colon);
                        int port = Integer.parseInt(address.substring(colon + 1));
                        RpcClient client = new RpcClient(host, port);
                        RpcResponse response = client.send(request);
                        if (response == null) {
                            throw new IllegalStateException("no response from RPC server " + address);
                        }
                        if (response.isError()) {
                            throw response.getError();
                        }
                        return response.getResult();
                    }
                });
    }

    /**
     * Handles the java.lang.Object methods a proxy dispatches: equals, hashCode, toString.
     */
    private static Object invokeObjectMethod(Object proxy, Method method, Object[] args,
            Class<?> interfaceClass) {
        String name = method.getName();
        if ("equals".equals(name)) {
            return proxy == args[0];
        }
        if ("hashCode".equals(name)) {
            return System.identityHashCode(proxy);
        }
        if ("toString".equals(name)) {
            return "RpcProxy[" + interfaceClass.getName() + "]";
        }
        throw new UnsupportedOperationException(name);
    }
}
3.RPC客户端
/**
 * One-shot Netty client: connects to a server, sends a single {@link RpcRequest} and
 * blocks until the matching {@link RpcResponse} arrives or the connection closes.
 * The instance is also the pipeline's inbound handler for the response.
 */
public class RpcClient extends SimpleChannelInboundHandler<RpcResponse> {
    private static final Logger LOGGER = LoggerFactory
            .getLogger(RpcClient.class);
    private String host;
    private int port;
    // Set by the event-loop thread, read by the thread blocked in send(); guarded by obj.
    private RpcResponse response;
    // True once a response arrived or the connection went away; guarded by obj.
    private boolean finished;
    private final Object obj = new Object();

    /**
     * @param host server host
     * @param port server port
     */
    public RpcClient(String host, int port) {
        this.host = host;
        this.port = port;
    }

    /**
     * Connects, sends {@code request} and waits for the reply.
     *
     * @param request request to send
     * @return the server's response, or null if the connection closed without one
     * @throws Exception on connect/write failure or interruption
     */
    public RpcResponse send(RpcRequest request) throws Exception {
        EventLoopGroup group = new NioEventLoopGroup();
        try {
            Bootstrap bootstrap = new Bootstrap();
            bootstrap.group(group).channel(NioSocketChannel.class)
                    .handler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        public void initChannel(SocketChannel channel)
                                throws Exception {
                            // Encoder, decoder, then this instance as the response handler.
                            channel.pipeline()
                                    .addLast(new RpcEncoder(RpcRequest.class))
                                    .addLast(new RpcDecoder(RpcResponse.class))
                                    .addLast(RpcClient.this);
                        }
                    }).option(ChannelOption.SO_KEEPALIVE, true);
            ChannelFuture future = bootstrap.connect(host, port).sync();
            future.channel().writeAndFlush(request).sync();
            RpcResponse result;
            synchronized (obj) {
                // Guarded wait: the response may already have arrived before this point,
                // and wait() can wake spuriously, so wait on the condition, not once.
                while (!finished) {
                    obj.wait();
                }
                result = response;
            }
            if (result != null) {
                // The server closes the connection after replying; wait for that.
                future.channel().closeFuture().sync();
            }
            return result;
        } finally {
            group.shutdownGracefully();
        }
    }

    /**
     * Stores the server's response and wakes the thread waiting in {@link #send}.
     */
    @Override
    public void channelRead0(ChannelHandlerContext ctx, RpcResponse response)
            throws Exception {
        synchronized (obj) {
            this.response = response;
            finished = true;
            obj.notifyAll();
        }
    }

    /**
     * Wakes {@link #send} when the connection closes, even if no response was received,
     * so the caller does not block forever.
     */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        synchronized (obj) {
            finished = true;
            obj.notifyAll();
        }
        super.channelInactive(ctx);
    }

    /**
     * Logs the error and closes the connection, which in turn wakes {@link #send}.
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
            throws Exception {
        LOGGER.error("client caught exception", cause);
        ctx.close();
    }
}
网友评论