API 网关安全:全面防护策略
核心概念
API 网关作为微服务架构的入口,是安全防护的第一道防线。合理的安全配置可以有效保护后端服务,防止各种攻击。
请求认证
// JWT 认证过滤器 @Component public class JwtAuthenticationFilter extends OncePerRequestFilter { private final JwtTokenProvider tokenProvider; public JwtAuthenticationFilter(JwtTokenProvider tokenProvider) { this.tokenProvider = tokenProvider; } @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { String token = extractToken(request); if (token != null && tokenProvider.validateToken(token)) { String username = tokenProvider.getUsernameFromToken(token); UserDetails userDetails = userDetailsService.loadUserByUsername(username); UsernamePasswordAuthenticationToken auth = new UsernamePasswordAuthenticationToken(userDetails, null, userDetails.getAuthorities()); auth.setDetails(new WebAuthenticationDetailsSource().buildDetails(request)); SecurityContextHolder.getContext().setAuthentication(auth); } filterChain.doFilter(request, response); } private String extractToken(HttpServletRequest request) { String bearerToken = request.getHeader("Authorization"); if (bearerToken != null && bearerToken.startsWith("Bearer ")) { return bearerToken.substring(7); } return null; } } // JWT 令牌提供者 @Component public class JwtTokenProvider { private final String secretKey = "your-secret-key"; private final long validityInMilliseconds = 3600000; // 1小时 public String createToken(String username, List<String> roles) { Claims claims = Jwts.claims().setSubject(username); claims.put("roles", roles); Date now = new Date(); Date validity = new Date(now.getTime() + validityInMilliseconds); return Jwts.builder() .setClaims(claims) .setIssuedAt(now) .setExpiration(validity) .signWith(SignatureAlgorithm.HS256, secretKey) .compact(); } public boolean validateToken(String token) { try { Jwts.parser().setSigningKey(secretKey).parseClaimsJws(token); return true; } catch (JwtException | IllegalArgumentException e) { return false; } } public String getUsernameFromToken(String token) { return 
Jwts.parser() .setSigningKey(secretKey) .parseClaimsJws(token) .getBody() .getSubject(); } }请求限流
// 限流过滤器 @Component public class RateLimitFilter extends OncePerRequestFilter { private final RedisTemplate<String, String> redisTemplate; private static final String LIMIT_PREFIX = "rate:limit:"; private static final int MAX_REQUESTS = 100; private static final int TIME_WINDOW_SECONDS = 60; public RateLimitFilter(RedisTemplate<String, String> redisTemplate) { this.redisTemplate = redisTemplate; } @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { String clientIp = getClientIp(request); String key = LIMIT_PREFIX + clientIp; String countStr = redisTemplate.opsForValue().get(key); int count = countStr != null ? Integer.parseInt(countStr) : 0; if (count >= MAX_REQUESTS) { response.setStatus(HttpStatus.TOO_MANY_REQUESTS.value()); response.getWriter().write("Rate limit exceeded"); return; } redisTemplate.opsForValue().increment(key); redisTemplate.expire(key, TIME_WINDOW_SECONDS, TimeUnit.SECONDS); filterChain.doFilter(request, response); } private String getClientIp(HttpServletRequest request) { String ip = request.getHeader("X-Forwarded-For"); if (ip == null || ip.isEmpty() || "unknown".equalsIgnoreCase(ip)) { ip = request.getHeader("Proxy-Client-IP"); } if (ip == null || ip.isEmpty() || "unknown".equalsIgnoreCase(ip)) { ip = request.getHeader("WL-Proxy-Client-IP"); } if (ip == null || ip.isEmpty() || "unknown".equalsIgnoreCase(ip)) { ip = request.getRemoteAddr(); } if (ip != null && ip.contains(",")) { ip = ip.split(",")[0].trim(); } return ip; } } // 基于令牌桶的限流 @Component public class TokenBucketRateLimiter { private final RedisTemplate<String, String> redisTemplate; private static final String BUCKET_PREFIX = "token:bucket:"; public TokenBucketRateLimiter(RedisTemplate<String, String> redisTemplate) { this.redisTemplate = redisTemplate; } public boolean tryAcquire(String key, int capacity, int refillRatePerSecond) { String bucketKey = BUCKET_PREFIX + key; 
return redisTemplate.execute((RedisCallback<Boolean>) connection -> { long now = System.currentTimeMillis(); byte[] keyBytes = bucketKey.getBytes(StandardCharsets.UTF_8); byte[] value = connection.get(keyBytes); if (value == null) { // 初始化桶 String initialValue = now + "," + capacity; connection.set(keyBytes, initialValue.getBytes(StandardCharsets.UTF_8)); connection.expire(keyBytes, 3600); return true; } String[] parts = new String(value, StandardCharsets.UTF_8).split(","); long lastRefillTime = Long.parseLong(parts[0]); int tokens = Integer.parseInt(parts[1]); // 计算新增令牌 long elapsedSeconds = (now - lastRefillTime) / 1000; int newTokens = (int) (elapsedSeconds * refillRatePerSecond); tokens = Math.min(tokens + newTokens, capacity); if (tokens > 0) { tokens--; String newValue = now + "," + tokens; connection.set(keyBytes, newValue.getBytes(StandardCharsets.UTF_8)); return true; } return false; }); } }请求验证
// 请求参数验证 @RestControllerAdvice public class ValidationExceptionHandler { @ExceptionHandler(MethodArgumentNotValidException.class) public ResponseEntity<Map<String, Object>> handleValidationExceptions( MethodArgumentNotValidException ex) { Map<String, String> errors = new HashMap<>(); ex.getBindingResult().getAllErrors().forEach(error -> { String fieldName = ((FieldError) error).getField(); String errorMessage = error.getDefaultMessage(); errors.put(fieldName, errorMessage); }); Map<String, Object> response = new HashMap<>(); response.put("status", "error"); response.put("message", "Validation failed"); response.put("errors", errors); return ResponseEntity.badRequest().body(response); } @ExceptionHandler(ConstraintViolationException.class) public ResponseEntity<Map<String, Object>> handleConstraintViolation( ConstraintViolationException ex) { Map<String, String> errors = new HashMap<>(); ex.getConstraintViolations().forEach(violation -> { String fieldName = violation.getPropertyPath().toString(); String errorMessage = violation.getMessage(); errors.put(fieldName, errorMessage); }); Map<String, Object> response = new HashMap<>(); response.put("status", "error"); response.put("message", "Validation failed"); response.put("errors", errors); return ResponseEntity.badRequest().body(response); } } // 请求 DTO 验证示例 public class UserCreateRequest { @NotBlank(message = "Email is required") @Email(message = "Invalid email format") private String email; @NotBlank(message = "Password is required") @Size(min = 8, message = "Password must be at least 8 characters") private String password; @NotBlank(message = "Name is required") @Size(max = 100, message = "Name cannot exceed 100 characters") private String name; // getters and setters }安全头设置
// 安全响应头配置 @Component public class SecurityHeaderFilter extends OncePerRequestFilter { @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { // 设置安全头 response.setHeader("X-Content-Type-Options", "nosniff"); response.setHeader("X-Frame-Options", "DENY"); response.setHeader("X-XSS-Protection", "1; mode=block"); response.setHeader("Content-Security-Policy", "default-src 'self'"); response.setHeader("Strict-Transport-Security", "max-age=31536000; includeSubDomains"); response.setHeader("X-Permitted-Cross-Domain-Policies", "none"); response.setHeader("Referrer-Policy", "strict-origin-when-cross-origin"); // 移除服务器信息 response.setHeader("Server", ""); filterChain.doFilter(request, response); } } // CORS 配置 @Configuration public class CorsConfig { @Bean public WebMvcConfigurer corsConfigurer() { return new WebMvcConfigurer() { @Override public void addCorsMappings(CorsRegistry registry) { registry.addMapping("/api/**") .allowedOrigins("https://example.com") .allowedMethods("GET", "POST", "PUT", "DELETE", "OPTIONS") .allowedHeaders("*") .allowCredentials(true) .maxAge(3600); } }; } }请求日志与审计
// 请求日志过滤器 @Component public class RequestLoggingFilter extends OncePerRequestFilter { private static final Logger logger = LoggerFactory.getLogger(RequestLoggingFilter.class); @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { long startTime = System.currentTimeMillis(); try { filterChain.doFilter(request, response); } finally { long duration = System.currentTimeMillis() - startTime; logger.info( "Request: {} {} - Status: {} - Duration: {}ms - IP: {}", request.getMethod(), request.getRequestURI(), response.getStatus(), duration, getClientIp(request) ); } } private String getClientIp(HttpServletRequest request) { String ip = request.getHeader("X-Forwarded-For"); return ip != null ? ip.split(",")[0].trim() : request.getRemoteAddr(); } } // 审计日志服务 @Service public class AuditLogService { private final AuditLogRepository auditLogRepository; public AuditLogService(AuditLogRepository auditLogRepository) { this.auditLogRepository = auditLogRepository; } @Async public void log(String action, String resourceType, String resourceId, String userId, String clientIp) { AuditLog auditLog = new AuditLog(); auditLog.setAction(action); auditLog.setResourceType(resourceType); auditLog.setResourceId(resourceId); auditLog.setUserId(userId); auditLog.setClientIp(clientIp); auditLog.setCreatedAt(LocalDateTime.now()); auditLogRepository.save(auditLog); } public List<AuditLog> getLogsByUserId(String userId, int limit) { return auditLogRepository.findByUserIdOrderByCreatedAtDesc(userId, limit); } } // 审计日志实体 @Entity @Table(name = "audit_logs") public class AuditLog { @Id @GeneratedValue(strategy = GenerationType.IDENTITY) private Long id; private String action; private String resourceType; private String resourceId; private String userId; private String clientIp; private LocalDateTime createdAt; // getters and setters }异常处理
// 全局异常处理器 @RestControllerAdvice public class GlobalExceptionHandler { private static final Logger logger = LoggerFactory.getLogger(GlobalExceptionHandler.class); @ExceptionHandler(Exception.class) public ResponseEntity<ErrorResponse> handleException(Exception ex) { logger.error("Unexpected error", ex); ErrorResponse response = new ErrorResponse( HttpStatus.INTERNAL_SERVER_ERROR.value(), "Internal server error", LocalDateTime.now() ); return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).body(response); } @ExceptionHandler(ResourceNotFoundException.class) public ResponseEntity<ErrorResponse> handleResourceNotFound(ResourceNotFoundException ex) { ErrorResponse response = new ErrorResponse( HttpStatus.NOT_FOUND.value(), ex.getMessage(), LocalDateTime.now() ); return ResponseEntity.status(HttpStatus.NOT_FOUND).body(response); } @ExceptionHandler(UnauthorizedException.class) public ResponseEntity<ErrorResponse> handleUnauthorized(UnauthorizedException ex) { ErrorResponse response = new ErrorResponse( HttpStatus.UNAUTHORIZED.value(), ex.getMessage(), LocalDateTime.now() ); return ResponseEntity.status(HttpStatus.UNAUTHORIZED).body(response); } @ExceptionHandler(ForbiddenException.class) public ResponseEntity<ErrorResponse> handleForbidden(ForbiddenException ex) { ErrorResponse response = new ErrorResponse( HttpStatus.FORBIDDEN.value(), ex.getMessage(), LocalDateTime.now() ); return ResponseEntity.status(HttpStatus.FORBIDDEN).body(response); } } // 错误响应 public class ErrorResponse { private int status; private String message; private LocalDateTime timestamp; public ErrorResponse(int status, String message, LocalDateTime timestamp) { this.status = status; this.message = message; this.timestamp = timestamp; } // getters }API 安全最佳实践
- 使用 HTTPS:确保所有通信使用 HTTPS
- 认证与授权:使用 JWT 或 OAuth2 进行身份认证,并基于角色/权限对每个接口进行访问授权
- 输入验证:对所有输入进行严格验证
- 限流熔断:防止 API 被滥用
- 安全头:设置适当的安全响应头
- 日志审计:记录所有关键操作
- 异常处理:统一异常处理和错误响应
- 定期更新:及时更新依赖和补丁
实际应用场景
- API 网关安全:保护微服务入口
- OAuth2 认证:第三方应用授权
- API 限流:防止恶意请求
- 安全审计:追踪用户操作
总结
API 网关安全是微服务架构安全的重要组成部分。通过综合运用认证、限流、验证、日志等手段,可以构建安全可靠的 API 网关。
别叫我大神,叫我 Alex 就好。上面的实现其实还可以写得更优雅一些;合理的安全配置能让 API 网关更加安全、可靠。