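Protecting your API from abuse is essential, and rate limiting is a key part of API security. It helps prevent denial-of-service attacks, manages resource usage, and ensures fair usage across clients. Together, Spring Boot 3 and Bucket4j offer a powerful and flexible way to add rate limiting to your application.
In this article, we'll look at how to implement rate limiting in a Spring Boot 3 application with Bucket4j. We'll cover several approaches and provide practical examples that you can adapt to your needs.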
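Prerequisites
Before you begin, make sure you have:
• Java 17 or later (the minimum version required by Spring Boot 3).
• A basic understanding of Java, Spring Boot, and API development.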
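Implementation
The first step is to add the required dependencies to your pom.xml. The Maven coordinates are shown below; for build.gradle, use the same coordinates. The examples also use Lombok (@RequiredArgsConstructor, @Slf4j, @Data) and, in the dynamic example, Micrometer's MeterRegistry. The project is therefore assumed to include spring-boot-starter-web, Lombok and spring-boot-starter-actuator as well.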
<dependency>
    <groupId>com.bucket4j</groupId>
    <artifactId>bucket4j-core</artifactId>
    <version>8.3.0</version>
</dependency>
<dependency>
    <groupId>com.bucket4j</groupId>
    <artifactId>bucket4j-caffeine</artifactId>
    <version>8.3.0</version>
</dependency>
<dependency>
    <groupId>com.github.ben-manes.caffeine</groupId>
    <artifactId>caffeine</artifactId>
    <version>3.1.8</version>
</dependency>
Instead of jumping straight to the final code, we'll build the rate limiting step by step. Let's start by creating a basic REST controller.
@RestController
@RequestMapping("/api")
public class RateLimitedController {

    @GetMapping("/greeting")
    public String getGreeting() {
        return "Hello, World!";
    }
}
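Next, we configure the rate limit itself. The bucket holds at most 50 tokens and receives 40 new tokens once every minute.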
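import io.github.bucket4j.Bandwidth;
import io.github.bucket4j.Bucket;
import io.github.bucket4j.Refill;
import java.time.Duration;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

// Named BasicRateLimitConfig so it does not clash with the RateLimitConfig
// value class used in the dynamic example later on.
@Configuration
public class BasicRateLimitConfig {

    @Bean
    public Bucket createNewBucket() {
        // Capacity: at most 50 tokens (requests) can be consumed in a burst.
        long capacity = 50;
        // Add 40 tokens once every full minute (never exceeding the capacity).
        Refill refill = Refill.intervally(40, Duration.ofMinutes(1));
        Bandwidth limit = Bandwidth.classic(capacity, refill);
        return Bucket.builder().addLimit(limit).build();
    }
}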
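The following model makes the bucket above (capacity 50, 40 tokens per minute) more concrete. It is a simplified, stdlib-only sketch of a classic bandwidth with an interval refill, written for intuition rather than as Bucket4j's implementation. The class name IntervalTokenBucketModel is hypothetical. The clock is passed in as nanoseconds so the behaviour can be checked without waiting.

```java
import java.time.Duration;

// Simplified model (not Bucket4j's code) of a classic bandwidth with an
// interval refill: the bucket starts full, gains `refillTokens` once per whole
// `period`, and never holds more than `capacity` tokens.
final class IntervalTokenBucketModel {
    private final long capacity;
    private final long refillTokens;
    private final long periodNanos;
    private long tokens;
    private long lastRefillNanos;

    IntervalTokenBucketModel(long capacity, long refillTokens, Duration period, long startNanos) {
        this.capacity = capacity;
        this.refillTokens = refillTokens;
        this.periodNanos = period.toNanos();
        this.tokens = capacity; // the bucket starts full
        this.lastRefillNanos = startNanos;
    }

    // Takes one token at time nowNanos (an injected clock, so the model is testable).
    synchronized boolean tryConsume(long nowNanos) {
        refill(nowNanos);
        if (tokens == 0) {
            return false;
        }
        tokens--;
        return true;
    }

    private void refill(long nowNanos) {
        long periods = (nowNanos - lastRefillNanos) / periodNanos;
        if (periods > 0) {
            // Only whole periods count; a partial period adds nothing.
            tokens = Math.min(capacity, tokens + periods * refillTokens);
            lastRefillNanos += periods * periodNanos;
        }
    }
}
```

With these numbers, a client can send a burst of 50 requests at once. After that, 40 more become available only when a full minute has passed.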
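Now we need a rate-limiting interceptor. For each request, it tries to take one token from the bucket. If that succeeds, the request proceeds and the remaining token count is returned in a header. If it fails, the client receives HTTP 429 Too Many Requests.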
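import io.github.bucket4j.Bucket;
import io.github.bucket4j.ConsumptionProbe;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.RequiredArgsConstructor;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.HandlerInterceptor;

@Component
@RequiredArgsConstructor
public class RateLimitInterceptor implements HandlerInterceptor {

    // The single Bucket bean from the configuration above, shared by all clients.
    private final Bucket bucket;

    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler)
            throws Exception {
        ConsumptionProbe probe = bucket.tryConsumeAndReturnRemaining(1);
        if (probe.isConsumed()) {
            response.addHeader("X-Rate-Limit-Remaining", String.valueOf(probe.getRemainingTokens()));
            return true;
        }
        // Seconds until a token is available again (integer division rounds down).
        long waitForRefill = probe.getNanosToWaitForRefill() / 1_000_000_000;
        response.addHeader("X-Rate-Limit-Retry-After-Seconds", String.valueOf(waitForRefill));
        response.sendError(HttpStatus.TOO_MANY_REQUESTS.value(), "You have exhausted your API Request Quota");
        return false;
    }
}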
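The interceptor converts probe.getNanosToWaitForRefill() to seconds using integer division, which rounds down. A wait of 0.9 seconds is therefore reported as 0, even though a retry at that moment would still be rejected. If you would rather round up, a small helper like the one below could replace the division. RetryAfter.ceilSeconds is a hypothetical name and is not part of Bucket4j.

```java
import java.util.concurrent.TimeUnit;

// Hypothetical helper: converts a wait time in nanoseconds into whole seconds,
// rounding UP so a client is never told to retry after 0 seconds while it
// would still be rejected.
final class RetryAfter {
    private RetryAfter() {}

    static long ceilSeconds(long nanos) {
        if (nanos <= 0) {
            return 0;
        }
        long nanosPerSecond = TimeUnit.SECONDS.toNanos(1);
        return (nanos + nanosPerSecond - 1) / nanosPerSecond;
    }
}
```

In the interceptor, it would be used as String.valueOf(RetryAfter.ceilSeconds(probe.getNanosToWaitForRefill())).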
The interceptor is not registered yet, so it never runs. Let's fix that.
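import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;

@Configuration
public class WebMvcConfig implements WebMvcConfigurer {

    private final RateLimitInterceptor interceptor;

    public WebMvcConfig(RateLimitInterceptor interceptor) {
        this.interceptor = interceptor;
    }

    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        // Apply the rate limit to every endpoint under /api
        registry.addInterceptor(interceptor).addPathPatterns("/api/**");
    }
}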
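We now have a basic rate limiter. However, this version is not suitable for production. All clients share a single bucket, so one client that sends many requests can exhaust the quota and cause every other client to be rejected as well.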
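IP-Based Rate Limiting
IP-based rate limiting is closer to what real production systems need. Each client IP address gets its own bucket, so the limit applies per client rather than globally. This gives finer-grained control.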
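The core idea of per-client limiting is one independent quota per key. The stdlib-only sketch below illustrates it. ConcurrentHashMap.computeIfAbsent stands in for Caffeine's cache.get(ip, this::newBucket), and a plain counter without refill stands in for a Bucket4j bucket. PerClientQuota is a hypothetical name used only for illustration.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical sketch: one independent quota per client key (e.g. an IP).
// computeIfAbsent creates the quota on first use, like cache.get(ip, this::newBucket).
// The quota never refills; only the per-key isolation is shown.
final class PerClientQuota {
    private final int limit;
    private final Map<String, AtomicInteger> remaining = new ConcurrentHashMap<>();

    PerClientQuota(int limit) {
        this.limit = limit;
    }

    boolean tryAcquire(String clientKey) {
        AtomicInteger left = remaining.computeIfAbsent(clientKey, k -> new AtomicInteger(limit));
        // Decrement atomically, but never below zero.
        while (true) {
            int current = left.get();
            if (current <= 0) {
                return false;
            }
            if (left.compareAndSet(current, current - 1)) {
                return true;
            }
        }
    }
}
```

With a limit of 10, client 10.0.0.1 is rejected after its 10th request, while 10.0.0.2 still has its full quota. With a single shared quota, the first client's requests would also have used up the second client's allowance.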
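import com.github.benmanes.caffeine.cache.Cache;
import com.github.benmanes.caffeine.cache.Caffeine;
import io.github.bucket4j.Bandwidth;
import io.github.bucket4j.Bucket;
import io.github.bucket4j.ConsumptionProbe;
import io.github.bucket4j.Refill;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.time.Duration;
import java.util.concurrent.TimeUnit;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.HandlerInterceptor;

@Component
public class IpBasedRateLimitInterceptor implements HandlerInterceptor {

    // One bucket per client IP
    private final Cache<String, Bucket> cache;

    public IpBasedRateLimitInterceptor() {
        // Evict buckets of clients that have been idle for an hour. The expiry must be
        // longer than the refill period: an evicted bucket is recreated full, which
        // would otherwise reset the client's limit.
        this.cache = Caffeine.newBuilder()
                .expireAfterAccess(1, TimeUnit.HOURS)
                .build();
    }

    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler)
            throws Exception {
        String ip = getClientIP(request);
        Bucket bucket = cache.get(ip, this::newBucket);
        ConsumptionProbe probe = bucket.tryConsumeAndReturnRemaining(1);
        if (probe.isConsumed()) {
            response.addHeader("X-Rate-Limit-Remaining", String.valueOf(probe.getRemainingTokens()));
            return true;
        }
        long waitForRefill = probe.getNanosToWaitForRefill() / 1_000_000_000;
        response.addHeader("X-Rate-Limit-Retry-After-Seconds", String.valueOf(waitForRefill));
        response.sendError(HttpStatus.TOO_MANY_REQUESTS.value(),
                "Rate limit exceeded. Try again in " + waitForRefill + " seconds");
        return false;
    }

    private String getClientIP(HttpServletRequest request) {
        // Only trust X-Forwarded-For behind a reverse proxy you control:
        // clients can set this header to any value themselves.
        String xfHeader = request.getHeader("X-Forwarded-For");
        if (xfHeader == null || xfHeader.isBlank()) {
            return request.getRemoteAddr();
        }
        return xfHeader.split(",")[0].trim();
    }

    private Bucket newBucket(String ip) {
        // 10 requests per minute per IP
        return Bucket.builder()
                .addLimit(Bandwidth.classic(10, Refill.intervally(10, Duration.ofMinutes(1))))
                .build();
    }
}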
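getClientIP uses the first entry of X-Forwarded-For, which the client controls. By sending a different made-up address with each request, a client could obtain a fresh bucket every time. A common alternative is to count entries from the right. The sketch below does this, assuming the app sits behind a fixed number of trusted reverse proxies that each append the address of whoever connected to them. ClientIpResolver is a hypothetical name. Spring Boot's server.forward-headers-strategy property is another way to let the server resolve the client address.

```java
// Hypothetical helper. Assumes exactly `trustedProxies` reverse proxies in front of
// the app, each appending the address of whoever connected to it. The last
// `trustedProxies - 1` entries were then added by inner proxies about other proxies,
// so the entry at (length - trustedProxies) is the address the outermost trusted
// proxy saw, i.e. the real client. Entries further left are client-supplied.
final class ClientIpResolver {
    private ClientIpResolver() {}

    static String resolve(String xForwardedFor, String remoteAddr, int trustedProxies) {
        if (xForwardedFor == null || xForwardedFor.isBlank() || trustedProxies <= 0) {
            // No usable proxy information: the TCP peer is the client.
            return remoteAddr;
        }
        String[] hops = xForwardedFor.split(",");
        // Clamp so a shorter-than-expected header cannot cause an index error.
        int index = Math.max(0, hops.length - trustedProxies);
        return hops[index].trim();
    }
}
```

With one trusted proxy and the header "6.6.6.6, 203.0.113.7", the resolved address is 203.0.113.7. The spoofed 6.6.6.6 is ignored.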
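Of course, we need a test to verify that the implementation works. The integration test below assumes IpBasedRateLimitInterceptor is registered for /api/** (for example, by injecting it into WebMvcConfig instead of RateLimitInterceptor), so each client may make 10 requests per minute.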
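import static org.junit.jupiter.api.Assertions.assertEquals;

import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.web.client.TestRestTemplate;
import org.springframework.boot.test.web.server.LocalServerPort;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;

@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
class RateLimitedControllerTest {

    @LocalServerPort
    private int port;

    @Autowired
    private TestRestTemplate restTemplate;

    @Test
    void whenExceedingRateLimit_thenReceive429() {
        String url = "http://localhost:" + port + "/api/greeting";
        // Send 11 requests: the first 10 fit in the per-IP limit, the 11th exceeds it
        for (int i = 0; i < 11; i++) {
            ResponseEntity<String> response = restTemplate.getForEntity(url, String.class);
            if (i < 10) {
                assertEquals(HttpStatus.OK, response.getStatusCode());
            } else {
                assertEquals(HttpStatus.TOO_MANY_REQUESTS, response.getStatusCode());
            }
        }
    }
}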
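Dynamic Rate Limiting Based on System Load
Last but not least, let's build one more rate limiter. This one adjusts the limit to the current load on the application: the busier the server, the fewer requests each client may make.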
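import java.lang.management.ManagementFactory;
import java.lang.management.OperatingSystemMXBean;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

@Slf4j
@Component
public class SystemMetricsCollector {

    private final OperatingSystemMXBean osBean;

    public SystemMetricsCollector() {
        this.osBean = ManagementFactory.getOperatingSystemMXBean();
    }

    public SystemMetrics collectMetrics() {
        double cpuLoad = getProcessCpuLoad();
        Runtime runtime = Runtime.getRuntime();
        // Heap usage as a fraction (0.0 to 1.0) of the maximum heap size
        long usedMemory = runtime.totalMemory() - runtime.freeMemory();
        double memoryUsage = (double) usedMemory / runtime.maxMemory();
        return new SystemMetrics(cpuLoad, memoryUsage);
    }

    // CPU load as a fraction between 0.0 and 1.0 (negative if unavailable)
    private double getProcessCpuLoad() {
        if (osBean instanceof com.sun.management.OperatingSystemMXBean sunOsBean) {
            return sunOsBean.getProcessCpuLoad();
        }
        // Fallback: the system load average is not a fraction (it can exceed the
        // number of cores and is negative if unavailable), so normalize it.
        double loadAverage = osBean.getSystemLoadAverage();
        if (loadAverage < 0) {
            return loadAverage;
        }
        return Math.min(1.0, loadAverage / osBean.getAvailableProcessors());
    }
}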
Next, the SystemMetrics class that holds these values:
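import lombok.AllArgsConstructor;
import lombok.Data;

@Data
@AllArgsConstructor
public class SystemMetrics {
    private double cpuLoad;      // fraction, 0.0 to 1.0
    private double memoryUsage;  // fraction of max heap, 0.0 to 1.0
}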
Then we create the component that calculates the rate limit from these metrics.
// DynamicRateLimitCalculator.java
import java.time.Duration;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

@Component
@Slf4j
public class DynamicRateLimitCalculator {

    private static final int BASE_LIMIT = 100;
    private static final double CPU_THRESHOLD_HIGH = 0.8;
    private static final double CPU_THRESHOLD_MEDIUM = 0.5;
    private static final double MEMORY_THRESHOLD_HIGH = 0.8;
    private static final double MEMORY_THRESHOLD_MEDIUM = 0.5;

    public RateLimitConfig calculateLimit(SystemMetrics metrics) {
        int limit = BASE_LIMIT;
        // Adjust the limit based on CPU load
        limit = adjustLimitBasedOnCpu(limit, metrics.getCpuLoad());
        // Adjust the limit based on memory usage
        limit = adjustLimitBasedOnMemory(limit, metrics.getMemoryUsage());
        Duration refillDuration = calculateRefillDuration(metrics);
        log.debug("Calculated rate limit: {}/{}s", limit, refillDuration.getSeconds());
        return new RateLimitConfig(limit, refillDuration);
    }

    private int adjustLimitBasedOnCpu(int currentLimit, double cpuLoad) {
        if (cpuLoad > CPU_THRESHOLD_HIGH) {
            return (int) (currentLimit * 0.3); // severe reduction
        } else if (cpuLoad > CPU_THRESHOLD_MEDIUM) {
            return (int) (currentLimit * 0.6); // moderate reduction
        }
        return currentLimit;
    }

    private int adjustLimitBasedOnMemory(int currentLimit, double memoryUsage) {
        if (memoryUsage > MEMORY_THRESHOLD_HIGH) {
            return (int) (currentLimit * 0.4);
        } else if (memoryUsage > MEMORY_THRESHOLD_MEDIUM) {
            return (int) (currentLimit * 0.7);
        }
        return currentLimit;
    }

    // Higher load means a longer refill period, i.e. fewer requests over time
    private Duration calculateRefillDuration(SystemMetrics metrics) {
        double maxLoad = Math.max(metrics.getCpuLoad(), metrics.getMemoryUsage());
        if (maxLoad > 0.8) {
            return Duration.ofMinutes(2);
        } else if (maxLoad > 0.5) {
            return Duration.ofMinutes(1);
        }
        return Duration.ofSeconds(30);
    }
}

// RateLimitConfig.java
import java.time.Duration;
import lombok.AllArgsConstructor;
import lombok.Data;

@Data
@AllArgsConstructor
public class RateLimitConfig {
    private int limit;
    private Duration refillDuration;
}
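To see how strongly the calculator throttles, its arithmetic can be pulled out into plain static methods and evaluated without starting Spring. DynamicLimitMath below is a hypothetical stdlib-only mirror of DynamicRateLimitCalculator, using the same base limit, thresholds and multipliers. It is intended for experimentation, not as a replacement for the component.

```java
import java.time.Duration;

// Hypothetical stdlib-only mirror of DynamicRateLimitCalculator's arithmetic,
// using the same base limit, thresholds and multipliers.
final class DynamicLimitMath {
    private DynamicLimitMath() {}

    static int limit(double cpuLoad, double memoryUsage) {
        int limit = 100; // BASE_LIMIT
        if (cpuLoad > 0.8) {
            limit = (int) (limit * 0.3);
        } else if (cpuLoad > 0.5) {
            limit = (int) (limit * 0.6);
        }
        if (memoryUsage > 0.8) {
            limit = (int) (limit * 0.4);
        } else if (memoryUsage > 0.5) {
            limit = (int) (limit * 0.7);
        }
        return limit;
    }

    static Duration refill(double cpuLoad, double memoryUsage) {
        double maxLoad = Math.max(cpuLoad, memoryUsage);
        if (maxLoad > 0.8) {
            return Duration.ofMinutes(2);
        } else if (maxLoad > 0.5) {
            return Duration.ofMinutes(1);
        }
        return Duration.ofSeconds(30);
    }
}
```

At 10% CPU and 20% heap usage, for example, the limit stays at 100 requests per 30 seconds. With CPU and heap both at 90%, it drops to 12 requests per 2 minutes (100 * 0.3 = 30, then 30 * 0.4 = 12).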
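Now let's create a flexible rate limiter that works as a handler interceptor. It recalculates the limit in the background every 10 seconds and keeps one bucket per client address.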
// RateLimitConfigProvider.java: exposes the current limit (used by RateLimitMetrics)
public interface RateLimitConfigProvider {
    RateLimitConfig getRateLimitConfig();
}

// DynamicRateLimitInterceptor.java
import com.github.benmanes.caffeine.cache.Cache;
import com.github.benmanes.caffeine.cache.Caffeine;
import io.github.bucket4j.Bandwidth;
import io.github.bucket4j.Bucket;
import io.github.bucket4j.ConsumptionProbe;
import io.github.bucket4j.Refill;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Timer;
import jakarta.annotation.PreDestroy;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.time.Duration;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.HandlerInterceptor;

@Slf4j
@Component
public class DynamicRateLimitInterceptor implements HandlerInterceptor, RateLimitConfigProvider {

    private final Cache<String, Bucket> bucketCache;
    private final SystemMetricsCollector metricsCollector;
    private final DynamicRateLimitCalculator calculator;
    private final AtomicReference<RateLimitConfig> currentConfig;
    private final ScheduledExecutorService scheduler;
    private final RateLimitMetrics metrics;

    public DynamicRateLimitInterceptor(SystemMetricsCollector metricsCollector,
                                       DynamicRateLimitCalculator calculator,
                                       MeterRegistry meterRegistry) {
        this.metricsCollector = metricsCollector;
        this.calculator = calculator;
        // Start with 100 requests per minute until the first measurement arrives
        this.currentConfig = new AtomicReference<>(new RateLimitConfig(100, Duration.ofMinutes(1)));
        this.bucketCache = Caffeine.newBuilder()
                .expireAfterWrite(1, TimeUnit.HOURS)
                .build();
        this.scheduler = Executors.newSingleThreadScheduledExecutor();
        this.metrics = new RateLimitMetrics(meterRegistry, this);
        startMetricsUpdateTask();
    }

    private void startMetricsUpdateTask() {
        // Recalculate the limit immediately, then every 10 seconds
        scheduler.scheduleAtFixedRate(this::updateRateLimitConfig, 0, 10, TimeUnit.SECONDS);
    }

    private void updateRateLimitConfig() {
        try {
            SystemMetrics systemMetrics = metricsCollector.collectMetrics();
            RateLimitConfig newConfig = calculator.calculateLimit(systemMetrics);
            RateLimitConfig oldConfig = currentConfig.get();
            if (hasSignificantChange(oldConfig, newConfig)) {
                currentConfig.set(newConfig);
                log.info("Rate limit updated: {}/{}s",
                        newConfig.getLimit(), newConfig.getRefillDuration().getSeconds());
                // Clear the cache so buckets are recreated with the new limits
                // (every client then starts again with a full bucket)
                bucketCache.invalidateAll();
            }
        } catch (Exception e) {
            // Catch everything: an uncaught exception would cancel the scheduled task
            log.error("Error updating rate limit config", e);
        }
    }

    private boolean hasSignificantChange(RateLimitConfig oldConfig, RateLimitConfig newConfig) {
        double limitChange = Math.abs(1.0 - (double) newConfig.getLimit() / oldConfig.getLimit());
        return limitChange > 0.2; // 20% change threshold
    }

    @Override
    public RateLimitConfig getRateLimitConfig() {
        return currentConfig.get();
    }

    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler)
            throws Exception {
        String path = request.getRequestURI();
        String method = request.getMethod();
        Timer.Sample timerSample = metrics.startTimer();
        boolean rateLimited = false;
        try {
            metrics.recordRequest();
            String clientId = getClientIdentifier(request);
            Bucket bucket = bucketCache.get(clientId, this::createBucket);
            ConsumptionProbe probe = bucket.tryConsumeAndReturnRemaining(1);
            if (probe.isConsumed()) {
                addRateLimitHeaders(response, probe);
                return true;
            }
            rateLimited = true;
            metrics.incrementRateLimitExceeded();
            handleRateLimitExceeded(response, probe);
            return false;
        } finally {
            metrics.stopTimer(timerSample, path, method, rateLimited);
        }
    }

    private Bucket createBucket(String clientId) {
        RateLimitConfig config = currentConfig.get();
        return Bucket.builder()
                .addLimit(Bandwidth.classic(config.getLimit(),
                        Refill.intervally(config.getLimit(), config.getRefillDuration())))
                .build();
    }

    private String getClientIdentifier(HttpServletRequest request) {
        // Could combine multiple factors: IP, user ID, API key, etc.
        return request.getRemoteAddr();
    }

    private void addRateLimitHeaders(HttpServletResponse response, ConsumptionProbe probe) {
        RateLimitConfig config = currentConfig.get();
        response.addHeader("X-Rate-Limit-Limit", String.valueOf(config.getLimit()));
        response.addHeader("X-Rate-Limit-Remaining", String.valueOf(probe.getRemainingTokens()));
        // Note: getNanosToWaitForRefill() is 0 when the token was consumed
        response.addHeader("X-Rate-Limit-Reset",
                String.valueOf(probe.getNanosToWaitForRefill() / 1_000_000_000));
    }

    private void handleRateLimitExceeded(HttpServletResponse response, ConsumptionProbe probe)
            throws IOException {
        long retryAfterSeconds = probe.getNanosToWaitForRefill() / 1_000_000_000;
        response.setStatus(HttpStatus.TOO_MANY_REQUESTS.value());
        response.setContentType(MediaType.APPLICATION_JSON_VALUE);
        String errorMessage = String.format("Rate limit exceeded. Try again in %d seconds", retryAfterSeconds);
        response.getWriter().write(String.format("{\"error\": \"%s\", \"retryAfter\": %d}",
                errorMessage, retryAfterSeconds));
    }

    @PreDestroy
    public void shutdown() {
        scheduler.shutdown();
    }
}
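hasSignificantChange adds hysteresis: a new limit is applied only when it differs from the current one by more than 20%. This matters because every update calls bucketCache.invalidateAll(). The next request from each client then creates a new bucket that starts full, so frequent small updates would keep resetting everyone's quota. The check is reproduced below as a hypothetical stdlib-only LimitChange class so it can be tried in isolation.

```java
// Hypothetical stdlib-only copy of the interceptor's hasSignificantChange check:
// a new limit counts as significant only if it differs from the old one by more than 20%.
final class LimitChange {
    private LimitChange() {}

    static boolean isSignificant(int oldLimit, int newLimit) {
        double relativeChange = Math.abs(1.0 - (double) newLimit / oldLimit);
        return relativeChange > 0.2;
    }
}
```

For example, a drop from 100 to 85 (15%) is ignored, while a drop from 100 to 60 (40%) replaces the configuration and clears all buckets.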
Configuring the Spring Boot Application to Use the Rate Limiter
Now we need to configure the Spring Boot application to use the rate limiter we just implemented.
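import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;

// Not named RateLimitConfig: that name is already used by the value class above.
@Configuration
public class DynamicRateLimitWebConfig implements WebMvcConfigurer {

    private final DynamicRateLimitInterceptor rateLimiter;

    public DynamicRateLimitWebConfig(DynamicRateLimitInterceptor rateLimiter) {
        this.rateLimiter = rateLimiter;
    }

    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        registry.addInterceptor(rateLimiter).addPathPatterns("/api/**");
    }
}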
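With this configuration, DynamicRateLimitInterceptor is registered as an interceptor for all request paths that start with /api. If WebMvcConfig from the first example is still present, both interceptors run for these requests.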
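Creating Custom Metrics to Monitor the Application
To see how the rate limiter behaves in practice (how many requests it handles, how often clients hit the limit, and what the current limit is), it's a good idea to create custom metrics.
Here is the code for the custom metrics: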
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Gauge;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Timer;
import java.util.Map;

// Created by DynamicRateLimitInterceptor; not a Spring bean itself.
public class RateLimitMetrics {

    private final MeterRegistry meterRegistry;
    private final Counter rateLimitExceeded;
    private final Counter requestsTotal;
    private final Gauge currentLimit;

    public RateLimitMetrics(MeterRegistry registry, RateLimitConfigProvider configProvider) {
        this.meterRegistry = registry;
        this.rateLimitExceeded = Counter.builder("rate_limit.exceeded")
                .description("Number of rate limit exceeded events")
                .tag("type", "exceeded")
                .register(registry);
        this.requestsTotal = Counter.builder("rate_limit.requests")
                .description("Total number of requests processed")
                .tag("type", "total")
                .register(registry);
        // Reads the current limit from the provider whenever the gauge is sampled
        this.currentLimit = Gauge.builder("rate_limit.current", configProvider, this::getCurrentLimit)
                .description("Current rate limit value")
                .tag("type", "limit")
                .register(registry);
    }

    public Timer.Sample startTimer() {
        // Use the registry's clock for timing
        return Timer.start(meterRegistry);
    }

    public void stopTimer(Timer.Sample sample, String path, String method, boolean rateLimited) {
        // register() returns the existing timer if one with the same name and tags exists
        Timer timer = Timer.builder("rate_limit.request.duration")
                .description("Request duration through rate limiter")
                .tags("path", path,
                        "method", method,
                        "rate_limited", String.valueOf(rateLimited),
                        "component", "rate_limiter")
                .register(meterRegistry);
        sample.stop(timer);
    }

    public void incrementRateLimitExceeded() {
        rateLimitExceeded.increment();
    }

    public void recordRequest() {
        requestsTotal.increment();
    }

    private double getCurrentLimit(RateLimitConfigProvider provider) {
        return provider.getRateLimitConfig().getLimit();
    }

    public Map<String, Number> getCurrentMetrics() {
        return Map.of(
                "rateLimitExceeded", rateLimitExceeded.count(),
                "totalRequests", requestsTotal.count(),
                "currentLimit", currentLimit.value());
    }
}
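The code above creates the following metrics:
1. rate_limit.exceeded: counts how many times the rate limit was hit.
2. rate_limit.requests: counts the total number of requests processed.
3. rate_limit.current: shows the current rate limit value.
4. rate_limit.request.duration: times each request's pass through the rate limiter, tagged with the path, the HTTP method and whether the request was rate limited.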
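Best Practices and Considerations
1. Caching: in production, use a distributed store (such as Redis) so that rate limits are shared across all instances in a cluster. An in-process Caffeine cache enforces the limit only per instance.
2. Response headers: always include rate limit information in the response headers to help clients manage their request rate. Common headers include:
• X-Rate-Limit-Remaining: the number of requests remaining.
• X-Rate-Limit-Retry-After-Seconds: the number of seconds to wait before retrying.
For 429 responses, the standard HTTP Retry-After header can carry the same information.
3. Error handling: return a clear error message when a client exceeds the rate limit.
4. Monitoring: set up metrics to track rate limit events, and tune the limits based on observed usage patterns.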
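Conclusion
This article showed how to implement rate limiting in a Spring Boot 3 application with Bucket4j. We covered three approaches: a simple global time-based limit, per-IP-address limits, and limits that adapt to system load. Real-world scenarios may differ from the examples shown here.
Rate limiting is only one part of an API security strategy. Combined with other security measures, it helps you build robust protection for your API.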
Original article by s4a, SpringForAll Community