自定义过滤器一般继承org.apache.shiro.web.filter.authc.AuthenticatingFilter类;
AuthenticatingFilter继承自AccessControlFilter;
onAccessDenied方法是校验的入口,定义校验的流程;
executeLogin方法定义校验的过程;
在executeLogin中,先执行createToken方法,封装token;
然后,执行登录过程;
如果登录成功,执行onLoginSuccess;
如果登录失败,执行onLoginFailure;
登录失败时,一般会抛出AuthenticationException异常;
可以在onLoginFailure中向response的输出流写入提示信息(或像示例代码那样重定向),交给前端处理即可;
示例代码:
package cn.xo68.boot.auth.server.shiro.filter;
import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.Map;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import cn.xo68.boot.auth.core.domain.OAuth2AuthenticationToken;
import cn.xo68.boot.auth.core.domain.Oauth2Principal;
import cn.xo68.boot.auth.core.properties.OAuthResourceProperties;
import cn.xo68.boot.auth.server.properties.AuthServerProperties;
import cn.xo68.core.util.StringTools;
import org.apache.oltu.oauth2.client.OAuthClient;
import org.apache.oltu.oauth2.client.URLConnectionClient;
import org.apache.oltu.oauth2.client.request.OAuthClientRequest;
import org.apache.oltu.oauth2.client.response.OAuthAccessTokenResponse;
import org.apache.oltu.oauth2.common.OAuth;
import org.apache.oltu.oauth2.common.exception.OAuthProblemException;
import org.apache.oltu.oauth2.common.message.types.GrantType;
import org.apache.oltu.oauth2.common.message.types.ParameterStyle;
import org.apache.oltu.oauth2.rs.request.OAuthAccessResourceRequest;
import org.apache.shiro.authc.AuthenticationException;
import org.apache.shiro.authc.AuthenticationToken;
import org.apache.shiro.subject.Subject;
import org.apache.shiro.util.ThreadContext;
import org.apache.shiro.web.filter.authc.AuthenticatingFilter;
import org.apache.shiro.web.servlet.SimpleCookie;
import org.apache.shiro.web.subject.WebSubject;
import org.apache.shiro.web.util.WebUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.StringUtils;
/**
 * OAuth2 authentication filter.
 *
 * <p>For each request it resolves an access token, wraps it in an {@link OAuth2AuthenticationToken}
 * and logs a freshly built {@link WebSubject} in with it. The token is looked up, in order, in:
 * <ol>
 *   <li>the bearer token in the {@code Authorization} header or the query string,</li>
 *   <li>the configured {@link #accessTokenCookie},</li>
 *   <li>the result of exchanging the authorization-code request parameter at the OAuth2 token
 *       endpoint.</li>
 * </ol>
 *
 * @author wuxie
 * @date 2018-8-5
 */
public class OAuth2AuthenticationFilter extends AuthenticatingFilter {

    private static final Logger logger = LoggerFactory.getLogger(OAuth2AuthenticationFilter.class);

    /** Cookie read as a fallback source of the access token; may be left unset. */
    private SimpleCookie accessTokenCookie;
    /** Supplies the unauthorized-page URL used for error/failure redirects. */
    private AuthServerProperties authServerProperties;
    /** Supplies the token endpoint URL and client credentials for the code exchange. */
    private OAuthResourceProperties oAuthResourceProperties;
    /** Name of the request parameter carrying the OAuth2 authorization code. */
    private String authcCodeParam = "code";
    /** Client id (configurable, but not read by this class). */
    private String clientId;
    /** Client URL the server redirects back to; also sent as {@code redirect_uri} in the code exchange. */
    private String redirectUrl;
    /** OAuth2 server response type (configurable, but not read by this class). */
    private String responseType = "code";
    /** Failure URL (configurable, but not read by this class). */
    private String failureUrl;

    public SimpleCookie getAccessTokenCookie() {
        return accessTokenCookie;
    }

    public void setAccessTokenCookie(SimpleCookie accessTokenCookie) {
        this.accessTokenCookie = accessTokenCookie;
    }

    public AuthServerProperties getAuthServerProperties() {
        return authServerProperties;
    }

    public void setAuthServerProperties(AuthServerProperties authServerProperties) {
        this.authServerProperties = authServerProperties;
    }

    public OAuthResourceProperties getoAuthResourceProperties() {
        return oAuthResourceProperties;
    }

    public void setoAuthResourceProperties(OAuthResourceProperties oAuthResourceProperties) {
        this.oAuthResourceProperties = oAuthResourceProperties;
    }

    public void setAuthcCodeParam(String authcCodeParam) {
        this.authcCodeParam = authcCodeParam;
    }

    public void setClientId(String clientId) {
        this.clientId = clientId;
    }

    public void setRedirectUrl(String redirectUrl) {
        this.redirectUrl = redirectUrl;
    }

    public void setResponseType(String responseType) {
        this.responseType = responseType;
    }

    public void setFailureUrl(String failureUrl) {
        this.failureUrl = failureUrl;
    }

    /**
     * Builds the authentication token for the current request.
     *
     * <p>Tries the header/query bearer token first, then the access-token cookie, then the
     * authorization-code exchange. If none of them yields a token, an empty
     * {@link OAuth2AuthenticationToken} is returned so the subsequent login attempt fails, rather
     * than {@link #executeLogin} rejecting a {@code null} token.
     *
     * @throws Exception if building or sending the token-endpoint request fails with a
     *                   non-problem (system/transport) error
     */
    @Override
    protected AuthenticationToken createToken(ServletRequest request, ServletResponse response) throws Exception {
        HttpServletRequest httpRequest = WebUtils.toHttp(request);

        String accessToken = readBearerAccessToken(httpRequest);
        if (StringTools.isEmpty(accessToken)) {
            accessToken = readCookieAccessToken(httpRequest, response);
        }
        if (StringTools.isEmpty(accessToken)) {
            String code = httpRequest.getParameter(authcCodeParam);
            if (StringTools.isNotEmpty(code)) {
                try {
                    accessToken = exchangeCodeForAccessToken(code);
                } catch (OAuthProblemException e) {
                    logger.warn("过滤器中获取令牌令牌异常", e);
                }
            }
        }
        if (StringTools.isNotEmpty(accessToken)) {
            return buildToken(accessToken);
        }
        return new OAuth2AuthenticationToken();
    }

    /**
     * Reads a bearer access token from the {@code Authorization} header or the query string.
     *
     * @return the token, or {@code null} when none (or an invalid/ambiguous one) was sent
     */
    private String readBearerAccessToken(HttpServletRequest httpRequest) throws Exception {
        try {
            OAuthAccessResourceRequest oauthRequest =
                    new OAuthAccessResourceRequest(httpRequest, ParameterStyle.HEADER, ParameterStyle.QUERY);
            return oauthRequest.getAccessToken();
        } catch (OAuthProblemException e) {
            // Oltu validates in the constructor and throws when no bearer token is present.
            // Treat that as "not sent here" so the cookie and authorization-code fallbacks still run.
            logger.debug("No usable bearer access token in header/query: {}", e.getMessage());
            return null;
        }
    }

    /**
     * Reads the access token from the cookie actually sent by the client.
     *
     * <p>{@code SimpleCookie.getValue()} would return the configured template value, not the
     * request's cookie.
     *
     * @return the cookie value, or {@code null} if no cookie is configured or the request lacks it
     */
    private String readCookieAccessToken(HttpServletRequest httpRequest, ServletResponse response) {
        if (accessTokenCookie == null) {
            return null;
        }
        return accessTokenCookie.readValue(httpRequest, WebUtils.toHttp(response));
    }

    /**
     * Exchanges an authorization code for an access token at the configured token endpoint.
     *
     * @return the access token from the server's response
     * @throws OAuthProblemException if the server answers with an OAuth2 error
     */
    private String exchangeCodeForAccessToken(String code) throws Exception {
        OAuthClient oAuthClient = new OAuthClient(new URLConnectionClient());
        // NOTE(review): buildQueryMessage() puts client_secret into the token URL. RFC 6749 §4.1.3
        // expects form-body parameters (buildBodyMessage). Confirm the server accepts that before switching.
        OAuthClientRequest accessTokenRequest = OAuthClientRequest
                .tokenLocation(oAuthResourceProperties.getAccessTokenUrl())
                .setGrantType(GrantType.AUTHORIZATION_CODE)
                .setClientId(oAuthResourceProperties.getClientId())
                .setClientSecret(oAuthResourceProperties.getClientSecret())
                .setCode(code)
                .setRedirectURI(redirectUrl)
                .buildQueryMessage();
        OAuthAccessTokenResponse oAuthResponse = oAuthClient.accessToken(accessTokenRequest, OAuth.HttpMethod.POST);
        return oAuthResponse.getAccessToken();
    }

    /** Wraps an access token into a token whose principal and credential both carry it. */
    private static OAuth2AuthenticationToken buildToken(String accessToken) {
        OAuth2AuthenticationToken oAuth2AuthenticationToken = new OAuth2AuthenticationToken();
        Oauth2Principal oauth2Principal = new Oauth2Principal();
        oAuth2AuthenticationToken.setPrincipal(oauth2Principal);
        oAuth2AuthenticationToken.setCredential(accessToken);
        oauth2Principal.setAccessToken(accessToken);
        return oAuth2AuthenticationToken;
    }

    /**
     * Always denies access up front, so every request goes through {@link #onAccessDenied}.
     *
     * <p>Shiro lets the request through when {@code isAccessAllowed || onAccessDenied} is true.
     * Because this always returns {@code false}, the outcome is decided by the login attempt.
     */
    @Override
    protected boolean isAccessAllowed(ServletRequest request, ServletResponse response, Object mappedValue) {
        return false;
    }

    /**
     * Handles a denied request: forwards OAuth2 server errors to the unauthorized page, otherwise
     * attempts a login.
     *
     * @return {@code true} to continue the filter chain, {@code false} to stop it
     */
    @Override
    protected boolean onAccessDenied(ServletRequest request, ServletResponse response) throws Exception {
        String error = request.getParameter("error");
        if (!StringUtils.isEmpty(error)) {
            // The OAuth2 server redirected back with an error. Pass it on to the unauthorized page.
            // The Map overload URL-encodes the values and joins them with '&'.
            Map<String, String> queryParams = new LinkedHashMap<>();
            queryParams.put("error", error);
            String errorDescription = request.getParameter("error_description");
            if (!StringUtils.isEmpty(errorDescription)) {
                queryParams.put("error_description", errorDescription);
            }
            WebUtils.issueRedirect(request, response, authServerProperties.getUnauthorizedUrl(), queryParams);
            return false;
        }
        return executeLogin(request, response);
    }

    /**
     * Performs the login attempt with the token from {@link #createToken}.
     *
     * <p>Overrides the parent implementation. The original author's note says the parent version
     * "has a bug" and that this follows an officially suggested rewrite (source not cited; TODO
     * confirm). Unlike the parent, it builds a fresh {@link WebSubject} from this request/response,
     * logs it in and binds it to the current thread on success.
     *
     * @throws IllegalStateException if {@link #createToken} returns {@code null}
     */
    @Override
    protected boolean executeLogin(ServletRequest request, ServletResponse response) throws Exception {
        AuthenticationToken token = createToken(request, response);
        if (token == null) {
            String msg = "createToken method implementation returned null. A valid non-null AuthenticationToken must be created in order to execute a login attempt.";
            throw new IllegalStateException(msg);
        }
        try {
            Subject subject = new WebSubject.Builder(request, response).buildSubject();
            subject.login(token);
            ThreadContext.bind(subject);
            return onLoginSuccess(token, subject, request, response);
        } catch (AuthenticationException e) {
            return onLoginFailure(token, e, request, response);
        }
    }

    /** Lets the request continue to the protected resource; no redirect is issued. */
    @Override
    protected boolean onLoginSuccess(AuthenticationToken token, Subject subject, ServletRequest request,
                                     ServletResponse response) throws Exception {
        return true;
    }

    /**
     * Handles a failed login (token could not be authenticated).
     *
     * <p>If the thread-bound subject is nevertheless authenticated or remembered, it is sent to the
     * success URL. Otherwise it is redirected to the unauthorized page. Redirect failures are
     * logged. The filter chain always stops.
     *
     * @return always {@code false}
     */
    @Override
    protected boolean onLoginFailure(AuthenticationToken token, AuthenticationException ae, ServletRequest request,
                                     ServletResponse response) {
        logger.debug("OAuth2 login failed", ae);
        Subject subject = getSubject(request, response);
        if (subject.isAuthenticated() || subject.isRemembered()) {
            try {
                issueSuccessRedirect(request, response);
            } catch (Exception e) {
                logger.error("Failed to issue success redirect", e);
            }
        } else {
            try {
                WebUtils.issueRedirect(request, response, authServerProperties.getUnauthorizedUrl());
            } catch (IOException e) {
                logger.error("Failed to redirect to the unauthorized url", e);
            }
        }
        return false;
    }
}
网友评论