diff --git a/pom.xml b/pom.xml index 867c06e..ed0c0c6 100644 --- a/pom.xml +++ b/pom.xml @@ -71,6 +71,11 @@ zkh-data ${project.version} + + org.springdoc + springdoc-openapi-common + 1.8.0 + diff --git a/zkh-common/pom.xml b/zkh-common/pom.xml index 95c4c7d..d45743d 100644 --- a/zkh-common/pom.xml +++ b/zkh-common/pom.xml @@ -30,6 +30,10 @@ org.springframework.data spring-data-jpa + + org.springdoc + springdoc-openapi-common + diff --git a/zkh-common/src/main/java/vip/jcfd/common/dto/LoginResponse.java b/zkh-common/src/main/java/vip/jcfd/common/dto/LoginResponse.java new file mode 100644 index 0000000..f6eda39 --- /dev/null +++ b/zkh-common/src/main/java/vip/jcfd/common/dto/LoginResponse.java @@ -0,0 +1,25 @@ +package vip.jcfd.common.dto; + +import io.swagger.v3.oas.annotations.media.Schema; + +/** + * 登录响应DTO + */ +@Schema(description = "登录响应") +public record LoginResponse( + + @Schema(description = "访问令牌", example = "550e8400-e29b-41d4-a716-446655440000") + String accessToken, + + @Schema(description = "刷新令牌", example = "550e8400-e29b-41d4-a716-446655440001") + String refreshToken, + + @Schema(description = "令牌类型", example = "Bearer") + String tokenType, + + @Schema(description = "访问令牌过期时间(秒)", example = "1800") + long expiresIn, + + @Schema(description = "用户名", example = "admin") + String username +) {} diff --git a/zkh-common/src/main/java/vip/jcfd/common/dto/TokenRefreshRequest.java b/zkh-common/src/main/java/vip/jcfd/common/dto/TokenRefreshRequest.java new file mode 100644 index 0000000..31fc11d --- /dev/null +++ b/zkh-common/src/main/java/vip/jcfd/common/dto/TokenRefreshRequest.java @@ -0,0 +1,21 @@ +package vip.jcfd.common.dto; + +import io.swagger.v3.oas.annotations.Parameter; +import io.swagger.v3.oas.annotations.media.Schema; +import jakarta.validation.constraints.NotBlank; + +/** + * Token刷新请求DTO + */ +@Schema(description = "Token刷新请求") +public record TokenRefreshRequest( + + @Parameter(description = "刷新令牌") + @NotBlank(message = "刷新令牌不能为空") + @Schema(description = "刷新令牌", example = "550e8400-e29b-41d4-a716-446655440000") + String refreshToken, + + @Parameter(description = "设备标识") + @Schema(description = "设备标识", example = "web-desktop", required = false) + String deviceId +) {} diff --git a/zkh-common/src/main/java/vip/jcfd/common/dto/TokenRefreshResponse.java b/zkh-common/src/main/java/vip/jcfd/common/dto/TokenRefreshResponse.java new file mode 100644 index 0000000..38f50e4 --- /dev/null +++ b/zkh-common/src/main/java/vip/jcfd/common/dto/TokenRefreshResponse.java @@ -0,0 +1,19 @@ +package vip.jcfd.common.dto; + +import io.swagger.v3.oas.annotations.media.Schema; + +/** + * Token刷新响应DTO + */ +@Schema(description = "Token刷新响应") +public record TokenRefreshResponse( + + @Schema(description = "新的访问令牌", example = "550e8400-e29b-41d4-a716-446655440000") + String accessToken, + + @Schema(description = "新的刷新令牌", example = "550e8400-e29b-41d4-a716-446655440001") + String refreshToken, + + @Schema(description = "令牌类型", example = "Bearer") + String tokenType +) {} diff --git a/zkh-web/pom.xml b/zkh-web/pom.xml index 66e24e4..7789ccc 100644 --- a/zkh-web/pom.xml +++ b/zkh-web/pom.xml @@ -35,6 +35,12 @@ com.fasterxml.jackson.datatype jackson-datatype-jsr310 + + org.springdoc + springdoc-openapi-starter-webmvc-ui + 2.8.14 + provided + diff --git a/zkh-web/src/main/java/vip/jcfd/web/auth/CustomDaoAuthenticationProvider.java b/zkh-web/src/main/java/vip/jcfd/web/auth/CustomDaoAuthenticationProvider.java new file mode 100644 index 0000000..76ea409 --- /dev/null +++ b/zkh-web/src/main/java/vip/jcfd/web/auth/CustomDaoAuthenticationProvider.java @@ -0,0 +1,17 @@ +package vip.jcfd.web.auth; + +import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; +import org.springframework.security.authentication.dao.DaoAuthenticationProvider; +import org.springframework.security.core.userdetails.UserDetailsService; + +public class CustomDaoAuthenticationProvider extends DaoAuthenticationProvider { + + public CustomDaoAuthenticationProvider(UserDetailsService userDetailsService) { + super(userDetailsService); + } + + @Override + public boolean supports(Class authentication) { + return UsernamePasswordAuthenticationToken.class.equals(authentication); + } +} diff --git a/zkh-web/src/main/java/vip/jcfd/web/auth/RefreshTokenAuthProvider.java b/zkh-web/src/main/java/vip/jcfd/web/auth/RefreshTokenAuthProvider.java new file mode 100644 index 0000000..95b0c27 --- /dev/null +++ b/zkh-web/src/main/java/vip/jcfd/web/auth/RefreshTokenAuthProvider.java @@ -0,0 +1,24 @@ +package vip.jcfd.web.auth; + +import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; +import org.springframework.security.authentication.dao.DaoAuthenticationProvider; +import org.springframework.security.core.AuthenticationException; +import org.springframework.security.core.userdetails.UserDetails; +import org.springframework.security.core.userdetails.UserDetailsService; + +public class RefreshTokenAuthProvider extends DaoAuthenticationProvider { + + public RefreshTokenAuthProvider(UserDetailsService userDetailsService) { + super(userDetailsService); + } + + @Override + protected void additionalAuthenticationChecks(UserDetails userDetails, UsernamePasswordAuthenticationToken authentication) throws AuthenticationException { + + } + + @Override + public boolean supports(Class authentication) { + return RefreshTokenAuthenticationToken.class.isAssignableFrom(authentication); + } +} diff --git a/zkh-web/src/main/java/vip/jcfd/web/auth/RefreshTokenAuthenticationToken.java b/zkh-web/src/main/java/vip/jcfd/web/auth/RefreshTokenAuthenticationToken.java new file mode 100644 index 0000000..fd53e25 --- /dev/null +++ b/zkh-web/src/main/java/vip/jcfd/web/auth/RefreshTokenAuthenticationToken.java @@ -0,0 +1,16 @@ +package vip.jcfd.web.auth; + +import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; +import org.springframework.security.core.GrantedAuthority; + +import java.util.Collection; + +public class RefreshTokenAuthenticationToken extends UsernamePasswordAuthenticationToken { + public RefreshTokenAuthenticationToken(Object principal, Object credentials) { + super(principal, credentials); + } + + public RefreshTokenAuthenticationToken(Object principal, Object credentials, Collection authorities) { + super(principal, credentials, authorities); + } +} diff --git a/zkh-web/src/main/java/vip/jcfd/web/config/RedisConfig.java b/zkh-web/src/main/java/vip/jcfd/web/config/RedisConfig.java index ed9a1f3..98d4b97 100644 --- a/zkh-web/src/main/java/vip/jcfd/web/config/RedisConfig.java +++ b/zkh-web/src/main/java/vip/jcfd/web/config/RedisConfig.java @@ -1,5 +1,6 @@ package vip.jcfd.web.config; +import com.fasterxml.jackson.databind.ObjectMapper; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.data.redis.connection.RedisConnectionFactory; @@ -19,8 +20,13 @@ public class RedisConfig { } @Bean - public TokenRedisStorage tokenRedisTemplate(RedisConnectionFactory factory, StringRedisTemplate stringRedisTemplate) { - TokenRedisStorage tokenRedisStorage = new TokenRedisStorage(securityProps.getDuration(), stringRedisTemplate); + public TokenRedisStorage tokenRedisTemplate(RedisConnectionFactory factory, StringRedisTemplate stringRedisTemplate, ObjectMapper objectMapper) { + TokenRedisStorage tokenRedisStorage = new TokenRedisStorage( + securityProps.getAccessTokenDuration(), + securityProps.getRefreshTokenDuration(), + stringRedisTemplate, + objectMapper + ); tokenRedisStorage.setConnectionFactory(factory); tokenRedisStorage.setValueSerializer(new JdkSerializationRedisSerializer()); tokenRedisStorage.setKeySerializer(new StringRedisSerializer()); diff --git a/zkh-web/src/main/java/vip/jcfd/web/config/SpringDocConfig.java b/zkh-web/src/main/java/vip/jcfd/web/config/SpringDocConfig.java new file mode 100644 index 0000000..5c4026a --- /dev/null +++ b/zkh-web/src/main/java/vip/jcfd/web/config/SpringDocConfig.java @@ -0,0 +1,72 @@ +package vip.jcfd.web.config; + +import io.swagger.v3.oas.models.Operation; +import io.swagger.v3.oas.models.PathItem; +import io.swagger.v3.oas.models.media.*; +import io.swagger.v3.oas.models.parameters.RequestBody; +import io.swagger.v3.oas.models.responses.ApiResponse; +import io.swagger.v3.oas.models.responses.ApiResponses; +import org.springdoc.core.customizers.OpenApiCustomizer; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +@Configuration("_springDocConfig") +public class SpringDocConfig { + + @Bean + public OpenApiCustomizer openApiCustomizer() { + return (openAPI) -> { + openAPI.path("/login", new PathItem() + .post(new Operation() + .summary("登录接口") + .description("用于用户登录,返回token") + .addTagsItem("认证管理") + .requestBody(new RequestBody() + .description("帐号密码") + .required(true) + .content(new Content().addMediaType("application/json", new MediaType().schema(new Schema<>() + .addProperty("username", new StringSchema().example("admin")) + .addProperty("password", new StringSchema().example("123456")))))) + .responses(new ApiResponses() + .addApiResponse("成功", new ApiResponse() + .content(new Content().addMediaType("application/json", new MediaType().schema(new Schema<>() + .addProperty("data", new JsonSchema() + .addProperty("accessToken", new StringSchema().example("550e8400-e29b-41d4-a716-446655440000")) + .addProperty("refreshToken", new StringSchema().example("550e8400-e29b-41d4-a716-446655440001")) + .addProperty("tokenType", new StringSchema().example("Bearer")) + .addProperty("expiresIn", new NumberSchema().example(1800)) + .addProperty("username", new StringSchema().example("admin")) + ) + .addProperty("success", new BooleanSchema().example(true)) + .addProperty("code", new IntegerSchema().example(200)) + .addProperty("message", new StringSchema().example("登录成功")) + )))) + .addApiResponse("失败", new ApiResponse() + .content(new Content().addMediaType("application/json", new MediaType().schema(new Schema<>() + .addProperty("data", new StringSchema().example(null)) + .addProperty("success", new BooleanSchema().example(false)) + .addProperty("code", new IntegerSchema().example(401)) + .addProperty("message", new StringSchema().example("用户名或密码错误")) + )))) + ))); + openAPI.path("/logout", new PathItem() + .post(new Operation() + .summary("登出接口") + .description("用于用户登出") + .addTagsItem("认证管理") + .responses(new ApiResponses() + .addApiResponse("成功", new ApiResponse() + .content(new Content().addMediaType("application/json", new MediaType().schema(new Schema<>() + .addProperty("data", new StringSchema().example(null)) + .addProperty("success", new BooleanSchema().example(true)) + .addProperty("code", new IntegerSchema().example(200)) + .addProperty("message", new StringSchema().example("登出成功")) + ) + ) + ) + ) + ) + )); + }; + } +} diff --git a/zkh-web/src/main/java/vip/jcfd/web/config/WebSecurityConfig.java b/zkh-web/src/main/java/vip/jcfd/web/config/WebSecurityConfig.java index 50cd570..e798166 100644 --- a/zkh-web/src/main/java/vip/jcfd/web/config/WebSecurityConfig.java +++ b/zkh-web/src/main/java/vip/jcfd/web/config/WebSecurityConfig.java @@ -19,6 +19,8 @@ import org.springframework.scheduling.annotation.EnableScheduling; import org.springframework.scheduling.annotation.Scheduled; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.authentication.dao.DaoAuthenticationProvider; +import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder; import org.springframework.security.config.annotation.authentication.configuration.AuthenticationConfiguration; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; @@ -27,6 +29,7 @@ import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder; import org.springframework.security.crypto.password.PasswordEncoder; import org.springframework.security.web.AuthenticationEntryPoint; @@ -37,6 +40,9 @@ import org.springframework.security.web.authentication.AuthenticationSuccessHand import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter; import org.springframework.security.web.authentication.logout.LogoutHandler; import vip.jcfd.common.core.R; +import vip.jcfd.common.dto.LoginResponse; +import vip.jcfd.web.auth.CustomDaoAuthenticationProvider; +import vip.jcfd.web.auth.RefreshTokenAuthProvider; import vip.jcfd.web.config.props.SecurityProps; import vip.jcfd.web.filter.JsonUsernamePasswordAuthenticationFilter; import vip.jcfd.web.filter.TokenFilter; @@ -59,10 +65,18 @@ public class WebSecurityConfig { private final ObjectMapper objectMapper; private final TokenRedisStorage tokenRedisStorage; - public WebSecurityConfig(SecurityProps securityProps, ObjectMapper objectMapper, TokenRedisStorage tokenRedisStorage) { + public WebSecurityConfig(SecurityProps securityProps, + ObjectMapper objectMapper, + TokenRedisStorage tokenRedisStorage, + AuthenticationManagerBuilder builder, + UserDetailsService userDetailsService) { this.securityProps = securityProps; this.objectMapper = objectMapper; this.tokenRedisStorage = tokenRedisStorage; + builder.authenticationProvider(new RefreshTokenAuthProvider(userDetailsService)); + DaoAuthenticationProvider authenticationProvider = new CustomDaoAuthenticationProvider(userDetailsService); + authenticationProvider.setPasswordEncoder(new BCryptPasswordEncoder()); + builder.authenticationProvider(authenticationProvider); } @Scheduled(cron = "0 */30 * * * *") @@ -104,8 +118,6 @@ public class WebSecurityConfig { CustomAuthenticationEntryPoint authenticationEntryPoint = new CustomAuthenticationEntryPoint(objectMapper, tokenRedisStorage); http.formLogin(config -> { config.loginProcessingUrl("/login"); - config.failureHandler(authenticationEntryPoint); - config.successHandler(authenticationEntryPoint); }); http.csrf(AbstractHttpConfigurer::disable); http.logout(config -> { @@ -121,6 +133,7 @@ public class WebSecurityConfig { http.addFilterBefore(tokenFilter, UsernamePasswordAuthenticationFilter.class); JsonUsernamePasswordAuthenticationFilter filter = new JsonUsernamePasswordAuthenticationFilter(objectMapper, authenticationManager); filter.setAuthenticationSuccessHandler(authenticationEntryPoint); + filter.setAuthenticationFailureHandler(authenticationEntryPoint); http.addFilterAt(filter, UsernamePasswordAuthenticationFilter.class); return http.build(); } @@ -147,12 +160,47 @@ public class WebSecurityConfig { @Override public void onAuthenticationSuccess(HttpServletRequest request, HttpServletResponse response, Authentication authentication) throws IOException, ServletException { log.info("用户「{}」登录成功", authentication.getName()); - String token = UUID.randomUUID().toString(); - tokenRedisStorage.put(token, authentication); + + // 生成双重Token + String accessToken = UUID.randomUUID().toString(); + String refreshToken = UUID.randomUUID().toString(); + + // 存储Access Token + tokenRedisStorage.putAccessToken(accessToken, authentication); + + // 存储Refresh Token + String deviceId = extractDeviceId(request); + tokenRedisStorage.putRefreshToken(refreshToken, authentication.getName(), deviceId); + + // 构造登录响应 + LoginResponse loginResponse = new LoginResponse( + accessToken, + refreshToken, + "Bearer", + 1800, // 30分钟,秒数 + authentication.getName() + ); + response.setContentType("application/json;charset=UTF-8"); - R data = new R<>(HttpServletResponse.SC_OK, "登录成功", true, token); + R data = new R<>(HttpServletResponse.SC_OK, "登录成功", true, loginResponse); objectMapper.writeValue(response.getWriter(), data); } + + private String extractDeviceId(HttpServletRequest request) { + // 尝试从User-Agent提取设备信息 + String userAgent = request.getHeader("User-Agent"); + if (userAgent != null) { + // 简单的设备识别逻辑,生产环境可以使用更复杂的识别算法 + if (userAgent.contains("Mobile") || userAgent.contains("Android") || userAgent.contains("iPhone")) { + return "mobile-" + request.getRemoteAddr(); + } else if (userAgent.contains("Tablet") || userAgent.contains("iPad")) { + return "tablet-" + request.getRemoteAddr(); + } else { + return "desktop-" + request.getRemoteAddr(); + } + } + return "unknown-" + request.getRemoteAddr(); + } } private record CustomAccessDeniedHandler(ObjectMapper objectMapper) implements AccessDeniedHandler { diff --git a/zkh-web/src/main/java/vip/jcfd/web/config/props/SecurityProps.java b/zkh-web/src/main/java/vip/jcfd/web/config/props/SecurityProps.java index 19df022..4140841 100644 --- a/zkh-web/src/main/java/vip/jcfd/web/config/props/SecurityProps.java +++ b/zkh-web/src/main/java/vip/jcfd/web/config/props/SecurityProps.java @@ -9,7 +9,11 @@ public class SecurityProps { private String[] ignoreUrls; - private Duration duration; + private Duration duration; // Legacy access token过期时间(兼容性) + + private Duration accessTokenDuration; + + private Duration refreshTokenDuration; public String[] getIgnoreUrls() { return ignoreUrls; @@ -26,4 +30,20 @@ public class SecurityProps { public void setDuration(Duration duration) { this.duration = duration; } + + public Duration getAccessTokenDuration() { + return accessTokenDuration != null ? accessTokenDuration : duration; + } + + public void setAccessTokenDuration(Duration accessTokenDuration) { + this.accessTokenDuration = accessTokenDuration; + } + + public Duration getRefreshTokenDuration() { + return refreshTokenDuration != null ? refreshTokenDuration : Duration.ofDays(7); + } + + public void setRefreshTokenDuration(Duration refreshTokenDuration) { + this.refreshTokenDuration = refreshTokenDuration; + } } diff --git a/zkh-web/src/main/java/vip/jcfd/web/controller/AuthController.java b/zkh-web/src/main/java/vip/jcfd/web/controller/AuthController.java new file mode 100644 index 0000000..5f66bb5 --- /dev/null +++ b/zkh-web/src/main/java/vip/jcfd/web/controller/AuthController.java @@ -0,0 +1,122 @@ +package vip.jcfd.web.controller; + +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.Parameter; +import io.swagger.v3.oas.annotations.tags.Tag; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.validation.Valid; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.http.HttpHeaders; +import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.AuthenticationException; +import org.springframework.web.bind.annotation.*; +import vip.jcfd.common.core.R; +import vip.jcfd.common.dto.TokenRefreshRequest; +import vip.jcfd.common.dto.TokenRefreshResponse; +import vip.jcfd.web.auth.RefreshTokenAuthenticationToken; +import vip.jcfd.web.redis.TokenRedisStorage; + +/** + * 认证控制器 + * 处理登录、登出、Token刷新等认证相关操作 + */ +@RestController +@RequestMapping("auth") +@Tag(name = "认证管理", description = "认证相关接口") +public class AuthController { + private static final Logger log = LoggerFactory.getLogger(AuthController.class); + private final AuthenticationManager authenticationManager; + private final TokenRedisStorage tokenRedisStorage; + + public AuthController(AuthenticationManager authenticationManager, TokenRedisStorage tokenRedisStorage) { + this.authenticationManager = authenticationManager; + this.tokenRedisStorage = tokenRedisStorage; + } + + @PostMapping("/refresh-token") + @Operation(summary = "刷新Token", description = "使用Refresh Token获取新的Access Token和Refresh Token") + public R refreshToken( + @Valid @RequestBody TokenRefreshRequest request) { + + log.info("Token刷新请求: {}", request.refreshToken().substring(0, Math.min(10, request.refreshToken().length())) + "..."); + + // 验证refresh token + if (tokenRedisStorage.existsRefreshToken(request.refreshToken())) { + log.warn("Refresh Token无效或已过期"); + return R.error("Refresh Token无效或已过期"); + } + + // 获取refresh token信息 + TokenRedisStorage.RefreshTokenInfo tokenInfo = tokenRedisStorage.getRefreshTokenInfo(request.refreshToken()); + if (tokenInfo == null) { + log.warn("无法解析Refresh Token信息"); + return R.error("Refresh Token信息无效"); + } + + try { + // 重新认证用户 + RefreshTokenAuthenticationToken authToken = + new RefreshTokenAuthenticationToken(tokenInfo.username(), ""); + Authentication authentication = authenticationManager.authenticate(authToken); + + // 使用TokenRedisStorage的刷新方法 + TokenRedisStorage.TokenRefreshResult refreshResult = + tokenRedisStorage.refreshAccessToken(request.refreshToken(), authentication); + + if (refreshResult == null) { + return R.error("Token刷新失败"); + } + + log.info("Token刷新成功,用户: {}", tokenInfo.username()); + + TokenRefreshResponse response = new TokenRefreshResponse( + refreshResult.accessToken(), + refreshResult.refreshToken(), + "Bearer" + ); + + return R.success(response); + + } catch (AuthenticationException e) { + log.error("用户认证失败,用户: {}", tokenInfo.username(), e); + return R.error("用户认证失败"); + } catch (Exception e) { + log.error("Token刷新失败", e); + return R.serverError("Token刷新失败"); + } + } + + @PostMapping("/logout") + @Operation(summary = "登出", description = "登出当前设备或所有设备") + public R logout( + @Parameter(description = "是否登出所有设备") @RequestParam(value = "all", defaultValue = "false") boolean all, + HttpServletRequest request) { + + String header = request.getHeader(HttpHeaders.AUTHORIZATION); + if (header != null && header.startsWith("Bearer ")) { + String accessToken = header.substring(7); + Authentication authentication = tokenRedisStorage.get(accessToken); + + if (authentication != null) { + String username = authentication.getName(); + log.info("用户「{}」登出请求,all: {}", username, all); + + if (all) { + // 登出所有设备 + tokenRedisStorage.removeByUserName(username); + log.info("用户「{}」已从所有设备登出", username); + } else { + // 登出当前设备 + tokenRedisStorage.removeAccessToken(accessToken); + log.info("用户「{}」已从当前设备登出", username); + } + + return R.success("登出成功"); + } + } + + return R.error("无效的Token"); + } +} diff --git a/zkh-web/src/main/java/vip/jcfd/web/redis/TokenRedisStorage.java b/zkh-web/src/main/java/vip/jcfd/web/redis/TokenRedisStorage.java index c86848c..3c719d4 100644 --- a/zkh-web/src/main/java/vip/jcfd/web/redis/TokenRedisStorage.java +++ b/zkh-web/src/main/java/vip/jcfd/web/redis/TokenRedisStorage.java @@ -1,97 +1,170 @@ package vip.jcfd.web.redis; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.data.redis.core.RedisTemplate; import org.springframework.data.redis.core.StringRedisTemplate; import org.springframework.security.core.Authentication; import org.springframework.util.CollectionUtils; +import vip.jcfd.common.core.BizException; import java.time.Duration; -import java.util.ArrayList; -import java.util.HashSet; import java.util.List; import java.util.Set; +import java.util.UUID; public class TokenRedisStorage extends RedisTemplate { private final static Logger logger = LoggerFactory.getLogger(TokenRedisStorage.class); - private final Duration expire; + private final Duration accessTokenExpire; + private final Duration refreshTokenExpire; private final StringRedisTemplate stringRedisTemplate; + private final ObjectMapper objectMapper; - private static final String TOKEN_KEY_PREFIX = "TOKEN:"; - private static final String TOKEN_LIST_KEY_PREFIX = TOKEN_KEY_PREFIX + "LIST:"; + private static final String ACCESS_TOKEN_KEY_PREFIX = "ACCESS_TOKEN:"; + private static final String REFRESH_TOKEN_KEY_PREFIX = "REFRESH_TOKEN:"; + private static final String USER_REFRESH_TOKENS_PREFIX = "USER_REFRESH_TOKENS:"; - public TokenRedisStorage(Duration expire, StringRedisTemplate stringRedisTemplate) { - this.expire = expire; + public TokenRedisStorage(Duration accessTokenExpire, Duration refreshTokenExpire, StringRedisTemplate stringRedisTemplate, ObjectMapper objectMapper) { + this.accessTokenExpire = accessTokenExpire; + this.refreshTokenExpire = refreshTokenExpire; this.stringRedisTemplate = stringRedisTemplate; + this.objectMapper = objectMapper; } + // Access Token相关方法 public Authentication get(String token) { - return opsForValue().get(TOKEN_KEY_PREFIX + token); + return opsForValue().get(ACCESS_TOKEN_KEY_PREFIX + token); } + public void putAccessToken(String accessToken, Authentication authentication) { + opsForValue().set(ACCESS_TOKEN_KEY_PREFIX + accessToken, authentication, accessTokenExpire); + } + + public boolean existsAccessToken(String token) { + return opsForValue().get(ACCESS_TOKEN_KEY_PREFIX + token) != null; + } + + public void removeAccessToken(String accessToken) { + expire(ACCESS_TOKEN_KEY_PREFIX + accessToken, Duration.ZERO); + } + + // Refresh Token相关方法 + public void putRefreshToken(String refreshToken, String username, String deviceId) { + // 存储refresh token信息 + RefreshTokenInfo tokenInfo = new RefreshTokenInfo(username, deviceId); + try { + stringRedisTemplate.opsForValue().set(REFRESH_TOKEN_KEY_PREFIX + refreshToken, objectMapper.writeValueAsString(tokenInfo), refreshTokenExpire); + // 添加到用户的refresh token列表 + stringRedisTemplate.opsForList().leftPush(USER_REFRESH_TOKENS_PREFIX + username, refreshToken); + } catch (JsonProcessingException e) { + logger.error("序列化tokenInfo出错", e); + throw new BizException("序列化tokenInfo出错", e); + } + + } + + public RefreshTokenInfo getRefreshTokenInfo(String refreshToken) { + String json = stringRedisTemplate.opsForValue().get(REFRESH_TOKEN_KEY_PREFIX + refreshToken); + if (json != null) { + try { + return objectMapper.readValue(json, RefreshTokenInfo.class); + } catch (JsonProcessingException e) { + logger.error("反序列化tokenInfo出错", e); + throw new BizException("反序列化tokenInfo出错", e); + } + } + return null; + } + + public boolean existsRefreshToken(String refreshToken) { + return stringRedisTemplate.opsForValue().get(REFRESH_TOKEN_KEY_PREFIX + refreshToken) == null; + } + + public void removeRefreshToken(String refreshToken) { + RefreshTokenInfo tokenInfo = getRefreshTokenInfo(refreshToken); + if (tokenInfo != null) { + // 从用户列表中移除 + stringRedisTemplate.opsForList().remove(USER_REFRESH_TOKENS_PREFIX + tokenInfo.username(), 0, refreshToken); + } + // 删除refresh token + expire(REFRESH_TOKEN_KEY_PREFIX + refreshToken, Duration.ZERO); + } + + // 兼容性方法 - 保持向后兼容 public void put(String token, Authentication authentication) { - opsForValue().set(TOKEN_KEY_PREFIX + token, authentication, expire); - stringRedisTemplate.opsForList().leftPush(TOKEN_LIST_KEY_PREFIX + authentication.getName(), token); + putAccessToken(token, authentication); } public void remove(String token) { - Authentication authentication = get(token); - if (authentication != null) { - stringRedisTemplate.opsForList().remove(TOKEN_LIST_KEY_PREFIX + authentication.getName(), 0, token); - expire(TOKEN_KEY_PREFIX + token, Duration.ZERO); - } + removeAccessToken(token); } public boolean exists(String token) { - return opsForValue().get(TOKEN_KEY_PREFIX + token) != null; + return existsAccessToken(token); } public void removeByUserName(String username) { - List range = stringRedisTemplate.opsForList().range(TOKEN_LIST_KEY_PREFIX + username, 0, -1); - if (CollectionUtils.isEmpty(range)) { - return; + // 清理所有refresh tokens + List refreshTokens = stringRedisTemplate.opsForList().range(USER_REFRESH_TOKENS_PREFIX + username, 0, -1); + if (!CollectionUtils.isEmpty(refreshTokens)) { + for (String refreshToken : refreshTokens) { + expire(REFRESH_TOKEN_KEY_PREFIX + refreshToken, Duration.ZERO); + } } - for (String s : range) { - expire(TOKEN_KEY_PREFIX + s, Duration.ZERO); - } - expire(TOKEN_LIST_KEY_PREFIX + username, Duration.ZERO); + expire(USER_REFRESH_TOKENS_PREFIX + username, Duration.ZERO); } - private record TokenStorage(String key, Set tokens) { - public void addToken(String token) { - tokens.add(token); + // Token刷新相关方法 + public TokenRefreshResult refreshAccessToken(String refreshToken, Authentication newAuthentication) { + if (existsRefreshToken(refreshToken)) { + return null; } + + RefreshTokenInfo tokenInfo = getRefreshTokenInfo(refreshToken); + String newAccessToken = UUID.randomUUID().toString(); + String newRefreshToken = UUID.randomUUID().toString(); + String username = tokenInfo.username(); + + // 存储新的access token + putAccessToken(newAccessToken, newAuthentication); + + // 删除旧的refresh token + removeRefreshToken(refreshToken); + + // 创建并存储新的refresh token + putRefreshToken(newRefreshToken, username, tokenInfo.deviceId()); + + return new TokenRefreshResult(newAccessToken, newRefreshToken, username); } public void clearExpiredTokens() { logger.info("开始清理过期token"); - Set keys = keys(TOKEN_LIST_KEY_PREFIX + "*"); - if (CollectionUtils.isEmpty(keys)) { - logger.info("清理过期token完成"); - return; - } - List tokenStorages = new ArrayList<>(); - for (String key : keys) { - List range = stringRedisTemplate.opsForList().range(key, 0, -1); - if (CollectionUtils.isEmpty(range)) { - continue; - } - TokenStorage tokenStorage = new TokenStorage(key, new HashSet<>()); - tokenStorages.add(tokenStorage); - for (String token : range) { - if (!exists(token)) { - tokenStorage.addToken(token); + + // 清理过期的refresh tokens + Set refreshKeys = keys(USER_REFRESH_TOKENS_PREFIX + "*"); + if (!CollectionUtils.isEmpty(refreshKeys)) { + for (String key : refreshKeys) { + List refreshTokens = stringRedisTemplate.opsForList().range(key, 0, -1); + if (!CollectionUtils.isEmpty(refreshTokens)) { + for (String refreshToken : refreshTokens) { + if (existsRefreshToken(refreshToken)) { + stringRedisTemplate.opsForList().remove(key, 0, refreshToken); + } + } } } } - logger.info("收集过期token完成,共{}个token", tokenStorages.stream().map(TokenStorage::tokens).mapToInt(Set::size).sum()); - for (TokenStorage tokenStorage : tokenStorages) { - for (String token : tokenStorage.tokens) { - stringRedisTemplate.opsForList().remove(tokenStorage.key, 0, token); - } - } logger.info("清理过期token完成"); } + + // Refresh Token信息类 + public record RefreshTokenInfo(String username, String deviceId) { + } + + // Token刷新结果类 + public record TokenRefreshResult(String accessToken, String refreshToken, String username) { + } }