美文网首页工作专题java学习Spring Boot-Shiro
Spring Boot 集成 Shiro：前后端分离权限校验，自定义返回信息

Spring Boot 集成 Shiro：前后端分离权限校验，自定义返回信息

作者: 不敢预言的预言家 | 来源:发表于2018-06-29 09:31 被阅读1204次

先唠叨两句：Shiro 的坑真的太多了，和 Spring Boot 集成的时候更是坑上加坑。

相关的总结和教程文章已经很多了，有兴趣的可以自行上网搜索。
本着“拎包入住、粘贴即用”的原则，直接上代码。

项目源码:https://github.com/dk980241/spring-boot-template

涉及功能:

  • Shiro使用配置
  • Shiro redis 缓存
  • Shiro session redis
  • Shiro前后端分离校验URL
  • Shiro自定义返回信息,不做跳转

注意点和一些个人想法都在代码的注释里。

ShiroConfig.java

package site.yuyanjia.template.common.config;

import lombok.extern.slf4j.Slf4j;
import org.apache.shiro.authc.credential.HashedCredentialsMatcher;
import org.apache.shiro.cache.Cache;
import org.apache.shiro.cache.CacheException;
import org.apache.shiro.cache.CacheManager;
import org.apache.shiro.mgt.SecurityManager;
import org.apache.shiro.session.Session;
import org.apache.shiro.session.SessionException;
import org.apache.shiro.session.UnknownSessionException;
import org.apache.shiro.session.mgt.eis.AbstractSessionDAO;
import org.apache.shiro.spring.web.ShiroFilterFactoryBean;
import org.apache.shiro.subject.PrincipalCollection;
import org.apache.shiro.subject.Subject;
import org.apache.shiro.web.filter.AccessControlFilter;
import org.apache.shiro.web.filter.authc.LogoutFilter;
import org.apache.shiro.web.filter.authz.AuthorizationFilter;
import org.apache.shiro.web.filter.mgt.DefaultFilter;
import org.apache.shiro.web.mgt.DefaultWebSecurityManager;
import org.apache.shiro.web.session.mgt.DefaultWebSessionManager;
import org.apache.shiro.web.util.WebUtils;
import org.springframework.boot.autoconfigure.AutoConfigureAfter;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.connection.RedisConnection;
import org.springframework.data.redis.core.Cursor;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.ScanOptions;
import org.springframework.data.redis.serializer.StringRedisSerializer;
import org.springframework.http.HttpMethod;
import org.springframework.util.Assert;
import site.yuyanjia.template.website.realm.WebUserRealm;

import javax.servlet.Filter;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.io.Serializable;
import java.util.*;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;

/**
 * Shiro configuration.
 * <p>
 * Wires Shiro for a front-end/back-end separated application:
 * <ul>
 *     <li>the custom filters answer with JSON instead of redirecting;</li>
 *     <li>realm caching goes through a Redis-backed {@link CacheManager};</li>
 *     <li>sessions are stored in Redis through a custom {@link AbstractSessionDAO}.</li>
 * </ul>
 * Properties are bound from the {@value #PREFIX} prefix.
 *
 * @author seer
 * @date 2018/2/1 15:41
 */
@Configuration
@AutoConfigureAfter(RedisConfig.class)
@ConfigurationProperties(prefix = ShiroConfig.PREFIX)
@Slf4j
public class ShiroConfig {

    /**
     * Configuration property prefix.
     */
    public static final String PREFIX = "yuyanjia.shiro";

    /**
     * URL-to-filter chain definitions.
     * <p>
     * Each element is one "pattern=filter[,filter...]" entry, kept in configuration order.
     */
    private List<String> urlFilterList = new ArrayList<>();

    /**
     * Hash algorithm used to verify stored passwords.
     * <p>
     * NOTE(review): MD5 is weak for password hashing. It stays the default because changing it
     * invalidates every stored hash. Prefer a stronger algorithm for new deployments.
     */
    private String hashAlgorithm = "MD5";

    /**
     * Number of hash iterations.
     */
    private Integer hashIterations = 2;

    /**
     * Cache key prefix.
     */
    private static final String SHIRO_REDIS_CACHE_KEY_PREFIX = ShiroConfig.class.getName() + "_shiro.redis.cache_";

    /**
     * Session key prefix.
     * <p>
     * Unlike the cache prefix, it has no leading "_". {@link #generateSessionKey(Object)} inserts one
     * instead. It is kept unchanged so that keys of sessions already stored in Redis stay readable.
     */
    private static final String SHIRO_REDIS_SESSION_KEY_PREFIX = ShiroConfig.class.getName() + "shiro.redis.session_";

    /**
     * Content type of every JSON body written by the custom filters.
     * <p>
     * The explicit charset must be set before {@code getWriter()} is called. It makes the writer
     * encode the non-ASCII messages as UTF-8 instead of the container's default encoding.
     */
    private static final String JSON_CONTENT_TYPE = "application/json;charset=UTF-8";

    /**
     * COUNT hint passed to each SCAN call.
     * <p>
     * A bounded batch keeps every round trip short. A value such as {@code Integer.MAX_VALUE}
     * would make one SCAN call walk the whole keyspace, blocking Redis just like KEYS.
     */
    private static final long SCAN_BATCH_SIZE = 1000L;

    /**
     * Shiro filter factory.
     * <p>
     * Custom filters, registered via {@link ShiroFilterFactoryBean#setFilters(java.util.Map)}, replace
     * the default checks and return JSON instead of redirecting. The chain is built with
     * {@link ShiroFilterFactoryBean#setFilterChainDefinitions(java.lang.String)} because that form lets
     * one URL be checked by several filters. Using
     * {@link ShiroFilterFactoryBean#setFilterChainDefinitionMap(java.util.Map)} makes that awkward, so
     * it is not used.
     * <p>
     * Filter names are arbitrary as long as they match the URL configuration. The {@link DefaultFilter}
     * names are reused here, overriding Shiro's built-in filters of the same name:
     * <ul>
     *     <li>authentication: {@link WebUserFilter};</li>
     *     <li>permission: {@link WebPermissionsAuthorizationFilter};</li>
     *     <li>logout: {@link WebLogoutFilter}.</li>
     * </ul>
     *
     * @param securityManager security manager the filters run against
     * @return the configured filter factory
     */
    @Bean
    public ShiroFilterFactoryBean shiroFilter(SecurityManager securityManager) {
        ShiroFilterFactoryBean shiroFilterFactoryBean = new ShiroFilterFactoryBean();
        shiroFilterFactoryBean.setSecurityManager(securityManager);

        Map<String, Filter> filterMap = new LinkedHashMap<>();
        filterMap.put(DefaultFilter.authc.toString(), new WebUserFilter());
        filterMap.put(DefaultFilter.perms.toString(), new WebPermissionsAuthorizationFilter());
        filterMap.put(DefaultFilter.logout.toString(), new WebLogoutFilter());
        shiroFilterFactoryBean.setFilters(filterMap);

        // One definition per line, in the order they were configured.
        StringBuilder stringBuilder = new StringBuilder();
        urlFilterList.forEach(s -> stringBuilder.append(s).append("\n"));
        shiroFilterFactoryBean.setFilterChainDefinitions(stringBuilder.toString());

        return shiroFilterFactoryBean;
    }

    /**
     * Security manager.
     *
     * @param userRealm                custom realm, see {@link #userRealm(CacheManager, HashedCredentialsMatcher)}
     * @param shiroRedisSessionManager custom session manager, see {@link #shiroRedisSessionManager(RedisTemplate)}
     * @return the {@link org.apache.shiro.mgt.SecurityManager}
     */
    @Bean
    public SecurityManager securityManager(WebUserRealm userRealm, DefaultWebSessionManager shiroRedisSessionManager) {
        DefaultWebSecurityManager securityManager = new DefaultWebSecurityManager();
        securityManager.setRealm(userRealm);
        securityManager.setSessionManager(shiroRedisSessionManager);
        return securityManager;
    }

    /**
     * Credentials matcher that hashes the submitted password with {@link #hashAlgorithm} and
     * {@link #hashIterations} and compares it with the hex-encoded stored credential.
     *
     * @return the credentials matcher
     */
    @Bean
    public HashedCredentialsMatcher hashedCredentialsMatcher() {
        HashedCredentialsMatcher hashedCredentialsMatcher = new HashedCredentialsMatcher();
        hashedCredentialsMatcher.setHashAlgorithmName(hashAlgorithm);
        hashedCredentialsMatcher.setHashIterations(hashIterations);
        hashedCredentialsMatcher.setStoredCredentialsHexEncoded(true);
        return hashedCredentialsMatcher;
    }

    /**
     * User realm.
     * <p>
     * SQL results are already cached, see {@link site.yuyanjia.template.common.mapper.WebUserMapper}.
     * Shiro's own realm caching is disabled here because it still has unresolved problems; the cache
     * manager is attached so caching can be switched on later.
     *
     * @param shiroRedisCacheManager   cache manager for the realm
     * @param hashedCredentialsMatcher matcher used to verify passwords
     * @return the realm
     */
    @Bean
    public WebUserRealm userRealm(CacheManager shiroRedisCacheManager, HashedCredentialsMatcher hashedCredentialsMatcher) {
        WebUserRealm userRealm = new WebUserRealm();
        userRealm.setCredentialsMatcher(hashedCredentialsMatcher);

        userRealm.setCachingEnabled(false);
        userRealm.setAuthenticationCachingEnabled(false);
        userRealm.setAuthorizationCachingEnabled(false);
        userRealm.setCacheManager(shiroRedisCacheManager);
        return userRealm;
    }

    /**
     * Redis-backed cache manager.
     * <p>
     * Each cache stores entries under {@code generateCacheKey(name, key)}. Values go through the
     * template's value serializer. Bulk operations use incremental SCAN over the cache's key prefix.
     *
     * @param redisTemplateWithJdk template whose values are serialized with
     *                             {@link org.springframework.data.redis.serializer.JdkSerializationRedisSerializer},
     *                             because Shiro's objects do not survive other serializers well
     * @return the cache manager
     */
    @Bean
    public CacheManager shiroRedisCacheManager(RedisTemplate redisTemplateWithJdk) {
        // TODO seer 2018/6/28 17:07 deserialization of cached values is still problematic and needs a rewrite
        return new CacheManager() {
            @Override
            public <K, V> Cache<K, V> getCache(String s) throws CacheException {
                log.trace("shiro redis cache manager get cache. name={} ", s);

                return new Cache<K, V>() {
                    @Override
                    @SuppressWarnings("unchecked")
                    public V get(K k) throws CacheException {
                        log.trace("shiro redis cache get.{} K={}", s, k);
                        return (V) redisTemplateWithJdk.opsForValue().get(generateCacheKey(s, k));
                    }

                    /**
                     * Stores the value and returns the previous one.
                     * GETSET does both in one atomic round trip.
                     */
                    @Override
                    @SuppressWarnings("unchecked")
                    public V put(K k, V v) throws CacheException {
                        log.trace("shiro redis cache put.{} K={} V={}", s, k, v);
                        return (V) redisTemplateWithJdk.opsForValue().getAndSet(generateCacheKey(s, k), v);
                    }

                    @Override
                    @SuppressWarnings("unchecked")
                    public V remove(K k) throws CacheException {
                        log.trace("shiro redis cache remove.{} K={}", s, k);
                        String key = generateCacheKey(s, k);
                        V result = (V) redisTemplateWithJdk.opsForValue().get(key);
                        redisTemplateWithJdk.delete(key);
                        return result;
                    }

                    /**
                     * Deletes every entry of this cache.
                     * <p>
                     * Uses SCAN rather than KEYS, because KEYS blocks Redis while it walks the keyspace.
                     */
                    @Override
                    public void clear() throws CacheException {
                        log.trace("shiro redis cache clear.{}", s);
                        scanKeys(redisTemplateWithJdk, generateCacheKey(s, "*"), "shiro redis cache clear exception",
                                (connection, rawKey) -> connection.del(rawKey));
                    }

                    @Override
                    public int size() {
                        log.trace("shiro redis cache size.{}", s);
                        AtomicInteger count = new AtomicInteger(0);
                        scanKeys(redisTemplateWithJdk, generateCacheKey(s, "*"), "shiro redis cache size exception",
                                (connection, rawKey) -> count.getAndIncrement());
                        return count.get();
                    }

                    /**
                     * Returns the keys of this cache.
                     * <p>
                     * NOTE(review): these are the full Redis key strings (prefix + cache name + "_" + key),
                     * not the original {@code K} objects, which cannot be rebuilt from their string form.
                     */
                    @Override
                    @SuppressWarnings("unchecked")
                    public Set<K> keys() {
                        log.trace("shiro redis cache keys.{}", s);
                        StringRedisSerializer stringRedisSerializer = new StringRedisSerializer();
                        Set<K> keys = new HashSet<>();
                        scanKeys(redisTemplateWithJdk, generateCacheKey(s, "*"), "shiro redis cache keys exception",
                                (connection, rawKey) -> keys.add((K) stringRedisSerializer.deserialize(rawKey)));
                        return keys;
                    }

                    /**
                     * Returns the values of this cache.
                     * <p>
                     * Entries that expire between SCAN and GET are skipped.
                     */
                    @Override
                    @SuppressWarnings("unchecked")
                    public Collection<V> values() {
                        log.trace("shiro redis cache values.{}", s);
                        List<V> values = new ArrayList<>();
                        scanKeys(redisTemplateWithJdk, generateCacheKey(s, "*"), "shiro redis cache values exception",
                                (connection, rawKey) -> {
                                    V value = (V) readValue(redisTemplateWithJdk, connection, rawKey);
                                    if (value != null) {
                                        values.add(value);
                                    }
                                });
                        return values;
                    }
                };
            }
        };
    }

    /**
     * Session manager whose sessions are stored in Redis.
     * <p>
     * Each session is stored under {@code generateSessionKey(sessionId)}, with a Redis TTL equal to
     * the session timeout.
     *
     * @param redisTemplateWithJdk template whose values are serialized with
     *                             {@link org.springframework.data.redis.serializer.JdkSerializationRedisSerializer},
     *                             because Shiro's objects do not survive other serializers well
     * @return the session manager
     */
    @Bean
    public DefaultWebSessionManager shiroRedisSessionManager(RedisTemplate redisTemplateWithJdk) {
        DefaultWebSessionManager defaultWebSessionManager = new DefaultWebSessionManager();
        defaultWebSessionManager.setGlobalSessionTimeout(1800000);
        defaultWebSessionManager.setSessionValidationInterval(900000);
        defaultWebSessionManager.setDeleteInvalidSessions(true);
        defaultWebSessionManager.setSessionDAO(
                new AbstractSessionDAO() {
                    @Override
                    protected Serializable doCreate(Session session) {
                        Serializable sessionId = this.generateSessionId(session);
                        log.trace("shiro redis session create. sessionId={}", sessionId);
                        this.assignSessionId(session, sessionId);
                        redisTemplateWithJdk.opsForValue().set(generateSessionKey(sessionId), session, session.getTimeout(), TimeUnit.MILLISECONDS);
                        return sessionId;
                    }

                    @Override
                    protected Session doReadSession(Serializable sessionId) {
                        log.trace("shiro redis session read. sessionId={}", sessionId);
                        return (Session) redisTemplateWithJdk.opsForValue().get(generateSessionKey(sessionId));
                    }

                    @Override
                    public void update(Session session) throws UnknownSessionException {
                        log.trace("shiro redis session update. sessionId={}", session.getId());
                        redisTemplateWithJdk.opsForValue().set(generateSessionKey(session.getId()), session, session.getTimeout(), TimeUnit.MILLISECONDS);
                    }

                    @Override
                    public void delete(Session session) {
                        log.trace("shiro redis session delete. sessionId={}", session.getId());
                        redisTemplateWithJdk.delete(generateSessionKey(session.getId()));
                    }

                    /**
                     * Returns every session stored in Redis.
                     * <p>
                     * Values are read with the scanned raw key, instead of passing the byte[] back
                     * through the template's key serializer. Sessions that expire between SCAN and
                     * GET are skipped, so the result never contains null.
                     */
                    @Override
                    public Collection<Session> getActiveSessions() {
                        Set<Session> sessionSet = new HashSet<>();
                        scanKeys(redisTemplateWithJdk, generateSessionKey("*"), "shiro redis session getActiveSessions exception",
                                (connection, rawKey) -> {
                                    Session session = (Session) readValue(redisTemplateWithJdk, connection, rawKey);
                                    if (session != null) {
                                        sessionSet.add(session);
                                    }
                                });
                        return sessionSet;
                    }
                }
        );

        return defaultWebSessionManager;
    }

    /**
     * Iterates over every Redis key matching {@code pattern} using incremental SCAN.
     * <p>
     * Each raw key is handed to {@code action` together with the connection it was read from.
     * The connection is always closed afterwards, so repeated calls do not exhaust the
     * connection pool.
     *
     * @param redisTemplate template providing the connection factory
     * @param pattern       SCAN MATCH pattern
     * @param errorMessage  message logged if closing the cursor fails
     * @param action        callback invoked once per matching raw key
     */
    private static void scanKeys(RedisTemplate redisTemplate, String pattern, String errorMessage,
                                 BiConsumer<RedisConnection, byte[]> action) {
        RedisConnection redisConnection = redisTemplate.getConnectionFactory().getConnection();
        Assert.notNull(redisConnection, "redisConnection is null");
        try (Cursor<byte[]> cursor = redisConnection.scan(ScanOptions.scanOptions()
                .match(pattern)
                .count(SCAN_BATCH_SIZE)
                .build())) {
            while (cursor.hasNext()) {
                action.accept(redisConnection, cursor.next());
            }
        } catch (IOException e) {
            log.error(errorMessage, e);
        } finally {
            redisConnection.close();
        }
    }

    /**
     * Reads the value stored under a raw key and deserializes it with the template's value serializer.
     *
     * @param redisTemplate template providing the value serializer
     * @param connection    open connection to read with
     * @param rawKey        raw key bytes, as returned by SCAN
     * @return the deserialized value, or {@code null} if the key no longer exists
     */
    private static Object readValue(RedisTemplate redisTemplate, RedisConnection connection, byte[] rawKey) {
        byte[] rawValue = connection.get(rawKey);
        return rawValue == null ? null : redisTemplate.getValueSerializer().deserialize(rawValue);
    }

    /**
     * Builds a cache key.
     *
     * @param name cache name
     * @param key  entry key (its {@code toString()} is used)
     * @return the Redis key
     */
    private static String generateCacheKey(String name, Object key) {
        return SHIRO_REDIS_CACHE_KEY_PREFIX + name + "_" + key;
    }

    /**
     * Builds a session key.
     *
     * @param key session id (its {@code toString()} is used)
     * @return the Redis key
     */
    private static String generateSessionKey(Object key) {
        return SHIRO_REDIS_SESSION_KEY_PREFIX + "_" + key;
    }

    /**
     * Writes a JSON body with an explicit UTF-8 content type.
     *
     * @param response response to write to
     * @param body     JSON text
     * @throws IOException if the writer cannot be obtained or written to
     */
    private static void writeJson(ServletResponse response, String body) throws IOException {
        response.setContentType(JSON_CONTENT_TYPE);
        response.getWriter().write(body);
    }

    /**
     * Replacement for the user filter.
     * <p>
     * Shiro's default is {@link org.apache.shiro.web.filter.authc.UserFilter}. This version lets the
     * login request and requests with a principal pass. Any other request gets a JSON "login expired"
     * body instead of a redirect.
     *
     * @author seer
     * @date 2018/6/17 22:30
     */
    static class WebUserFilter extends AccessControlFilter {
        @Override
        protected boolean isAccessAllowed(ServletRequest request, ServletResponse response, Object mappedValue) throws Exception {
            if (isLoginRequest(request, response)) {
                return true;
            }
            Subject subject = getSubject(request, response);
            return subject.getPrincipal() != null;
        }

        /**
         * Writes the JSON error and stops the chain, without any redirect.
         */
        @Override
        protected boolean onAccessDenied(ServletRequest request, ServletResponse response) throws Exception {
            writeJson(response, "{\"response_code\":\"9000\",\"response_msg\":\"登录过期\"}");
            return false;
        }
    }

    /**
     * Replacement for the permission filter.
     * <p>
     * Shiro's default is {@link org.apache.shiro.web.filter.authz.PermissionsAuthorizationFilter}.
     * For a separated front end, the request's servlet path itself is the permission being checked.
     * The configured permission values are request paths, see
     * {@link WebUserRealm#doGetAuthorizationInfo(PrincipalCollection)}.
     *
     * @author seer
     * @date 2018/6/17 22:41
     */
    static class WebPermissionsAuthorizationFilter extends AuthorizationFilter {
        @Override
        protected boolean isAccessAllowed(ServletRequest request, ServletResponse response, Object mappedValue) throws Exception {
            Subject subject = getSubject(request, response);
            String url = ((HttpServletRequest) request).getServletPath();
            return subject.isPermitted(url);
        }

        /**
         * Writes the JSON error and stops the chain, without the default redirect or 401.
         */
        @Override
        protected boolean onAccessDenied(ServletRequest request, ServletResponse response) throws IOException {
            writeJson(response, "{\"response_code\":\"90001\",\"response_msg\":\"权限不足\"}");
            return false;
        }
    }

    /**
     * Replacement for the logout filter.
     * <p>
     * Shiro's default is {@link LogoutFilter}. This version logs the subject out and answers with
     * JSON instead of redirecting. The success body is written only after the POST-only check, so a
     * rejected request does not also receive a success message.
     *
     * @author seer
     * @date 2018/6/26 2:09
     */
    static class WebLogoutFilter extends LogoutFilter {
        @Override
        protected boolean preHandle(ServletRequest request, ServletResponse response) throws Exception {
            Subject subject = getSubject(request, response);

            if (isPostOnlyLogout()
                    && !WebUtils.toHttp(request).getMethod().toUpperCase(Locale.ENGLISH).equals(HttpMethod.POST.toString())) {
                return onLogoutRequestNotAPost(request, response);
            }
            try {
                subject.logout();
            } catch (SessionException ise) {
                log.trace("Encountered session exception during logout.  This can generally safely be ignored.", ise);
            }
            writeJson(response, "{\"response_code\":\"0000\",\"response_msg\":\"SUCCES\"}");
            return false;
        }
    }

    public List<String> getUrlFilterList() {
        return urlFilterList;
    }

    public void setUrlFilterList(List<String> urlFilterList) {
        this.urlFilterList = urlFilterList;
    }

    public String getHashAlgorithm() {
        return hashAlgorithm;
    }

    public void setHashAlgorithm(String hashAlgorithm) {
        this.hashAlgorithm = hashAlgorithm;
    }

    public Integer getHashIterations() {
        return hashIterations;
    }

    public void setHashIterations(Integer hashIterations) {
        this.hashIterations = hashIterations;
    }
}

application-yml

yuyanjia:
  shiro:
    # Each entry is "<url pattern>=<filter>[,<filter>...]".
    # ShiroConfig#shiroFilter joins the entries, one per line, into Shiro's filter chain definitions.
    # Filter names map to the custom filters registered there:
    #   authc  -> WebUserFilter
    #   perms  -> WebPermissionsAuthorizationFilter
    #   logout -> WebLogoutFilter
    # "anon" is Shiro's built-in default filter name.
    # NOTE(review): Shiro uses the first pattern that matches a request, so keep the catch-all /** entry last.
    url-filter-list:
      - /website/user/user-login=anon
      - /website/user/user-logout=logout
      - /website/user/**=authc,perms
      - /**=anon

WebUserRealm

package site.yuyanjia.template.website.realm;

import org.apache.commons.collections.CollectionUtils;
import org.apache.shiro.authc.*;
import org.apache.shiro.authz.AuthorizationInfo;
import org.apache.shiro.authz.SimpleAuthorizationInfo;
import org.apache.shiro.realm.AuthorizingRealm;
import org.apache.shiro.subject.PrincipalCollection;
import org.apache.shiro.util.ByteSource;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.util.ObjectUtils;
import site.yuyanjia.template.common.mapper.WebPermissionMapper;
import site.yuyanjia.template.common.mapper.WebRolePermissionMapper;
import site.yuyanjia.template.common.mapper.WebUserMapper;
import site.yuyanjia.template.common.mapper.WebUserRoleMapper;
import site.yuyanjia.template.common.model.WebPermissionDO;
import site.yuyanjia.template.common.model.WebRolePermissionDO;
import site.yuyanjia.template.common.model.WebUserDO;
import site.yuyanjia.template.common.model.WebUserRoleDO;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;


/**
 * User realm.
 * <p>
 * Authentication loads the user by username and uses the whole {@link WebUserDO} as the principal,
 * so later code can use the entity directly. Authorization grants string permissions whose values
 * are the request paths called by the front end. Roles only exist to assign permissions to users in
 * bulk and are not checked themselves.
 *
 * @author seer
 * @date 2018/2/1 16:59
 */
public class WebUserRealm extends AuthorizingRealm {

    @Autowired
    private WebUserMapper webUserMapper;

    @Autowired
    private WebUserRoleMapper webUserRoleMapper;

    @Autowired
    private WebRolePermissionMapper webRolePermissionMapper;

    @Autowired
    private WebPermissionMapper webPermissionMapper;

    /**
     * Loads the authorization info of the current user.
     * <p>
     * Permissions are collected through user → roles → role permissions → permission rows. The
     * permission values (request paths) are then granted as string permissions.
     *
     * @param principalCollection principals of the current subject; the primary one is the user
     * @return authorization info, empty when the user has no roles or the roles have no permissions
     * @throws AccountException if the primary principal is empty
     */
    @Override
    protected AuthorizationInfo doGetAuthorizationInfo(PrincipalCollection principalCollection) {
        WebUserDO webUserDO = toWebUser(principalCollection.getPrimaryPrincipal());

        SimpleAuthorizationInfo authorizationInfo = new SimpleAuthorizationInfo();
        List<WebUserRoleDO> webUserRoleDOList = webUserRoleMapper.selectByUserId(webUserDO.getId());
        if (CollectionUtils.isEmpty(webUserRoleDOList)) {
            return authorizationInfo;
        }

        List<WebRolePermissionDO> webRolePermissionDOList = new ArrayList<>();
        webUserRoleDOList.forEach(
                webUserRoleDO -> webRolePermissionDOList.addAll(webRolePermissionMapper.selectByRoleId(webUserRoleDO.getRoleId()))
        );
        if (CollectionUtils.isEmpty(webRolePermissionDOList)) {
            return authorizationInfo;
        }

        // Several roles can share a permission, so ids are de-duplicated before querying.
        // Rows that no longer exist, or that have no value, are skipped instead of throwing a
        // NullPointerException.
        authorizationInfo.addStringPermissions(webRolePermissionDOList.stream()
                .map(WebRolePermissionDO::getPermissionId)
                .distinct()
                .map(webPermissionMapper::selectByPrimaryKey)
                .filter(webPermissionDO -> webPermissionDO != null)
                .map(WebPermissionDO::getPermissionValue)
                .filter(permissionValue -> permissionValue != null)
                .collect(Collectors.toSet()));
        return authorizationInfo;
    }

    /**
     * Converts the primary principal into a {@link WebUserDO}.
     * <p>
     * With spring-boot-devtools, application classes are loaded by
     * {@code org.springframework.boot.devtools.restart.classloader.RestartClassLoader} rather than the
     * application class loader. A principal may therefore not be an instance of this
     * {@code WebUserDO} class even though it has the same name. In that case its properties are
     * copied into a fresh instance instead of casting.
     *
     * @param principal primary principal
     * @return the user entity
     * @throws AccountException if the principal is empty
     */
    private WebUserDO toWebUser(Object principal) {
        if (ObjectUtils.isEmpty(principal)) {
            throw new AccountException("用户信息查询为空");
        }
        if (principal instanceof WebUserDO) {
            return (WebUserDO) principal;
        }
        WebUserDO webUserDO = new WebUserDO();
        BeanUtils.copyProperties(principal, webUserDO);
        return webUserDO;
    }

    /**
     * Loads the authentication info for a login attempt.
     * <p>
     * The user entity itself is used as the principal so later code can use it directly. The stored
     * password is the credential, and the user's salt (when present) is the credentials salt.
     *
     * @param authenticationToken login token; its principal is the username
     * @return authentication info for the credentials matcher
     * @throws AuthenticationException {@link UnknownAccountException} if no user has that username
     */
    @Override
    protected AuthenticationInfo doGetAuthenticationInfo(AuthenticationToken authenticationToken) throws AuthenticationException {
        String username = (String) authenticationToken.getPrincipal();
        WebUserDO webUserDO = webUserMapper.selectByUsername(username);
        if (ObjectUtils.isEmpty(webUserDO)) {
            throw new UnknownAccountException("用户 " + username + " 信息查询失败");
        }

        SimpleAuthenticationInfo authenticationInfo = new SimpleAuthenticationInfo(
                webUserDO,
                webUserDO.getPassword(),
                getName()
        );
        // A null salt would make ByteSource.Util.bytes throw; without a salt the matcher hashes unsalted.
        if (webUserDO.getSalt() != null) {
            ByteSource salt = ByteSource.Util.bytes(webUserDO.getSalt());
            authenticationInfo.setCredentialsSalt(salt);
        }
        return authenticationInfo;
    }

    /**
     * Clears cached authentication/authorization data for the given principals.
     *
     * @param principals principals whose cache entries are cleared
     */
    @Override
    protected void doClearCache(PrincipalCollection principals) {
        super.doClearCache(principals);
    }
}

具体的登录登出等使用方式不做赘述。

相关文章

网友评论

  • 十九贝勒:好长啊 ,为什么不使用spring-security?
    Carnation_Sean:@十九贝勒 ‘前后端分离权限检验’,前端检验呢?
    不敢预言的预言家:@十九贝勒 因为shiro相比Spring Security 简单很多。我们这种敏捷开发,基本都用shiro

本文标题：Spring Boot 集成 Shiro：前后端分离权限校验，自定义返回信息

本文链接:https://www.haomeiwen.com/subject/gpcnyftx.html