分布式限流器 - 算法实现与代码
2025/11/13大约 8 分钟
分布式限流器 - 算法实现与代码
一、限流算法详解
3.1 令牌桶算法 (Token Bucket)
算法原理
令牌桶算法是一种常用的限流算法,核心思想是:
- 令牌生成:系统以恒定速率向桶中放入令牌
- 令牌消费:每个请求需要消耗一个令牌
- 令牌不足:桶中没有令牌时,请求被拒绝
- 桶有容量:令牌数量有上限,多余的令牌会被丢弃
时间轴: t0 t1 t2 t3 t4
↓ ↓ ↓ ↓ ↓
令牌桶: [●●●] → [●●●●] → [●●●] → [●●●●] → [●●●]
↑生成 ↑请求 ↑生成 ↑请求 ↑生成
令牌 消耗 令牌 消耗 令牌算法特点
优点:
- 允许突发流量:桶中积累的令牌可以应对短时间的流量突发
- 平滑限流:长期来看,流量被平滑到恒定速率
- 实现简单:逻辑清晰,易于理解和实现
缺点:
- 突发流量可能过大:如果桶容量设置过大,可能导致瞬时流量过高
适用场景:
- 全局限流:保护网关整体,允许一定突发
- 服务级限流:保护后端服务,提高资源利用率
- IP级限流:防止恶意攻击,允许正常用户的突发请求
Redis + Lua 实现
-- 令牌桶算法 Lua 脚本
-- KEYS[1]: 限流 key
-- ARGV[1]: 限流阈值(每秒令牌数)
local key = KEYS[1]
local limit = tonumber(ARGV[1])
-- 获取当前计数
local current = tonumber(redis.call('get', key) or '0')
if current < limit then
-- 令牌充足,计数+1
redis.call('incr', key)
-- 首次计数时设置过期时间
if current == 0 then
redis.call('expire', key, 1)
end
return 1 -- 允许通过
else
return 0 -- 拒绝请求
end实现细节:
- 计数器重置:使用 Redis 的
expire命令,1秒后自动重置计数器 - 原子性保证:Lua 脚本在 Redis 中原子执行,避免并发问题
- 简化实现:不维护真实的令牌桶,而是用计数器模拟
Java 调用代码:
public boolean tryAcquireWithTokenBucket(String key, RateLimitConfig config) {
String script =
"local key = KEYS[1]\n" +
"local limit = tonumber(ARGV[1])\n" +
"local current = tonumber(redis.call('get', key) or '0')\n" +
"if current < limit then\n" +
" redis.call('incr', key)\n" +
" if current == 0 then\n" +
" redis.call('expire', key, 1)\n" +
" end\n" +
" return 1\n" +
"else\n" +
" return 0\n" +
"end";
Long result = redisTemplate.execute(
new DefaultRedisScript<>(script, Long.class),
Collections.singletonList(key),
config.getLimitCount().toString()
);
return result != null && result == 1;
}3.2 滑动窗口算法 (Sliding Window)
算法原理
滑动窗口算法将时间划分为多个窗口,统计窗口内的请求数:
- 时间窗口:定义一个时间窗口(如1秒)
- 请求记录:记录每个请求的时间戳
- 窗口滑动:随着时间推移,窗口向前滑动
- 计数判断:统计当前窗口内的请求数,超过阈值则拒绝
时间轴: |----1s----|----1s----|----1s----|
请求: ●●● ●●●● ●●
↑ ↑ ↑
窗口1 窗口2 窗口3
(3个) (4个) (2个)滑动窗口示例:
当前时间:1000ms
窗口大小:1000ms
窗口范围:[0ms, 1000ms]
请求时间戳:100ms, 200ms, 300ms, 1100ms
窗口内请求:100ms, 200ms, 300ms (3个)
窗口外请求:1100ms (被移除)算法特点
优点:
- 精确控制:严格限制时间窗口内的请求数
- 防止突发:不允许超过阈值的突发流量
- 实时性好:窗口实时滑动,响应及时
缺点:
- 存储开销:需要记录每个请求的时间戳
- 计算复杂:需要清理过期数据,计算窗口内请求数
适用场景:
- 接口级限流:保护核心接口(如登录、支付),需要精确控制
- 防止暴力破解:严格限制登录尝试次数
- 防止爬虫:精确控制单个IP的请求频率
Redis + Lua 实现
-- 滑动窗口算法 Lua 脚本
-- KEYS[1]: 限流 key
-- ARGV[1]: 限流阈值
-- ARGV[2]: 时间窗口(秒)
-- ARGV[3]: 当前时间戳(毫秒)
local key = KEYS[1]
local limit = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local current = tonumber(ARGV[3])
-- 计算窗口起始时间
local expire_time = current - window * 1000
-- 删除窗口外的过期数据
redis.call('zremrangebyscore', key, 0, expire_time)
-- 统计当前窗口内的请求数
local count = redis.call('zcard', key)
if count < limit then
-- 未达到限流阈值,添加当前请求
redis.call('zadd', key, current, current)
-- 设置过期时间(窗口大小 + 1秒)
redis.call('expire', key, window + 1)
return 1 -- 允许通过
else
return 0 -- 拒绝请求
end实现细节:
- 数据结构:使用 Redis 的 Sorted Set,score 为时间戳
- 过期清理:使用
zremrangebyscore删除窗口外的数据 - 计数统计:使用
zcard统计窗口内的请求数 - 自动过期:设置 key 的过期时间,避免内存泄漏
Java 调用代码:
public boolean tryAcquireWithSlidingWindow(String key, RateLimitConfig config) {
String script =
"local key = KEYS[1]\n" +
"local limit = tonumber(ARGV[1])\n" +
"local window = tonumber(ARGV[2])\n" +
"local current = tonumber(ARGV[3])\n" +
"local expire_time = current - window * 1000\n" +
"redis.call('zremrangebyscore', key, 0, expire_time)\n" +
"local count = redis.call('zcard', key)\n" +
"if count < limit then\n" +
" redis.call('zadd', key, current, current)\n" +
" redis.call('expire', key, window + 1)\n" +
" return 1\n" +
"else\n" +
" return 0\n" +
"end";
Long result = redisTemplate.execute(
new DefaultRedisScript<>(script, Long.class),
Collections.singletonList(key),
config.getLimitCount().toString(),
config.getTimeWindow().toString(),
String.valueOf(System.currentTimeMillis())
);
return result != null && result == 1;
}3.3 算法选择策略
| 场景 | 推荐算法 | 理由 |
|---|---|---|
| 全局限流 | 令牌桶 | 允许突发流量,提高资源利用率,避免过度限流 |
| 服务级限流 | 令牌桶 | 保护后端服务,允许短时突发,提升用户体验 |
| 接口级限流 | 滑动窗口 | 精确控制核心接口,防止接口被打垮 |
| IP级限流 | 令牌桶 | 防止恶意攻击,允许正常用户的突发请求 |
| 登录接口 | 滑动窗口 | 防止暴力破解,严格限制尝试次数 |
| 支付接口 | 滑动窗口 | 保护资金安全,精确控制请求频率 |
选择原则:
- 需要允许突发流量 → 令牌桶
- 需要精确控制 → 滑动窗口
- 性能要求高 → 令牌桶(实现更简单,性能更好)
- 安全要求高 → 滑动窗口(控制更严格)
二、核心代码实现
4.1 DistributedRateLimiter 完整实现
@Component
@Slf4j
public class DistributedRateLimiter {
@Resource
private RedisTemplate<String, Object> redisTemplate;
// 本地限流器缓存
private final Map<String, RateLimiter> localLimiters = new ConcurrentHashMap<>();
// 配置缓存
private final Map<String, RateLimitConfig> configCache = new ConcurrentHashMap<>();
// Lua 脚本缓存
private final RedisScript<Long> tokenBucketScript;
private final RedisScript<Long> slidingWindowScript;
public DistributedRateLimiter() {
// 加载 Lua 脚本
this.tokenBucketScript = loadScript("token_bucket.lua");
this.slidingWindowScript = loadScript("sliding_window.lua");
}
/**
* 尝试获取令牌
* @param key 限流 key
* @param config 限流配置
* @return true-允许通过,false-拒绝请求
*/
public boolean tryAcquire(String key, RateLimitConfig config) {
// 1. 检查配置是否启用
if (!config.getEnabled()) {
return true;
}
// 2. 本地限流检查
if (!tryAcquireLocal(key, config)) {
log.debug("本地限流拒绝: key={}", key);
return false;
}
// 3. Redis 分布式限流检查
try {
boolean result = tryAcquireDistributed(key, config);
if (!result) {
log.debug("Redis限流拒绝: key={}", key);
}
return result;
} catch (Exception e) {
log.error("Redis限流异常,降级为本地限流: key={}", key, e);
return true; // 降级策略:Redis异常时放行
}
}
/**
* 本地限流检查
*/
private boolean tryAcquireLocal(String key, RateLimitConfig config) {
RateLimiter limiter = getOrCreateLocalLimiter(key, config);
return limiter.tryAcquire();
}
/**
* Redis 分布式限流检查
*/
private boolean tryAcquireDistributed(String key, RateLimitConfig config) {
String redisKey = "rate_limit:" + key;
if ("TOKEN_BUCKET".equals(config.getStrategy())) {
return tryAcquireWithTokenBucket(redisKey, config);
} else if ("SLIDING_WINDOW".equals(config.getStrategy())) {
return tryAcquireWithSlidingWindow(redisKey, config);
} else {
throw new IllegalArgumentException("不支持的限流策略: " + config.getStrategy());
}
}
/**
* 令牌桶算法
*/
private boolean tryAcquireWithTokenBucket(String key, RateLimitConfig config) {
Long result = redisTemplate.execute(
tokenBucketScript,
Collections.singletonList(key),
config.getLimitCount().toString()
);
return result != null && result == 1;
}
/**
* 滑动窗口算法
*/
private boolean tryAcquireWithSlidingWindow(String key, RateLimitConfig config) {
Long result = redisTemplate.execute(
slidingWindowScript,
Collections.singletonList(key),
config.getLimitCount().toString(),
config.getTimeWindow().toString(),
String.valueOf(System.currentTimeMillis())
);
return result != null && result == 1;
}
/**
* 获取或创建本地限流器
*/
private RateLimiter getOrCreateLocalLimiter(String key, RateLimitConfig config) {
return localLimiters.computeIfAbsent(key, k -> {
// 本地限流器设置为配置值的 1.2 倍
double permitsPerSecond = config.getLimitCount() * 1.2 / config.getTimeWindow();
log.info("创建本地限流器: key={}, permitsPerSecond={}", key, permitsPerSecond);
return RateLimiter.create(permitsPerSecond);
});
}
/**
* 更新配置
*/
public void updateConfig(String key, RateLimitConfig config) {
configCache.put(key, config);
localLimiters.remove(key); // 清除旧的本地限流器
log.info("更新限流配置: key={}, config={}", key, config);
}
/**
* 删除配置
*/
public void removeConfig(String key) {
configCache.remove(key);
localLimiters.remove(key);
log.info("删除限流配置: key={}", key);
}
/**
* 获取配置
*/
public RateLimitConfig getConfig(String key) {
return configCache.get(key);
}
}4.2 RateLimitPreHandler 完整实现
@Component
@Order(10)
@Slf4j
public class RateLimitPreHandler implements CustomPreHandler {
@Resource
private DistributedRateLimiter rateLimiter;
@Resource
private RateLimitConfigService configService;
@Override
public Result<Void> handle(HttpStatement httpStatement, FullHttpRequest request) {
// 1. 全局限流检查
if (!checkGlobalLimit()) {
return Result.fail(ResultCode.RATE_LIMIT_EXCEEDED, "全局限流,请稍后重试");
}
// 2. 服务级限流检查
String serviceName = httpStatement.getServiceName();
if (!checkServiceLimit(serviceName)) {
return Result.fail(ResultCode.RATE_LIMIT_EXCEEDED, "服务限流,请稍后重试");
}
// 3. 接口级限流检查
String path = httpStatement.getPath();
if (!checkInterfaceLimit(serviceName, path)) {
return Result.fail(ResultCode.RATE_LIMIT_EXCEEDED, "接口限流,请稍后重试");
}
// 4. IP级限流检查
String clientIp = getClientIp(request);
if (!checkIpLimit(clientIp)) {
return Result.fail(ResultCode.RATE_LIMIT_EXCEEDED, "IP限流,请稍后重试");
}
return Result.success();
}
private boolean checkGlobalLimit() {
RateLimitConfig config = configService.getConfig("GLOBAL", "GLOBAL");
if (config == null) {
return true;
}
return rateLimiter.tryAcquire("GLOBAL:GLOBAL", config);
}
private boolean checkServiceLimit(String serviceName) {
RateLimitConfig config = configService.getConfig("SERVICE", serviceName);
if (config == null) {
return true;
}
return rateLimiter.tryAcquire("SERVICE:" + serviceName, config);
}
private boolean checkInterfaceLimit(String serviceName, String path) {
String target = serviceName + ":" + path;
RateLimitConfig config = configService.getConfig("INTERFACE", target);
if (config == null) {
return true;
}
return rateLimiter.tryAcquire("INTERFACE:" + target, config);
}
private boolean checkIpLimit(String clientIp) {
RateLimitConfig config = configService.getConfig("IP", clientIp);
if (config == null) {
return true;
}
return rateLimiter.tryAcquire("IP:" + clientIp, config);
}
private String getClientIp(FullHttpRequest request) {
// 从请求头获取真实IP
String ip = request.headers().get("X-Real-IP");
if (ip == null || ip.isEmpty()) {
ip = request.headers().get("X-Forwarded-For");
}
return ip;
}
@Override
public int getOrder() {
return 10; // 在鉴权之后执行
}
@Override
public boolean canRunParallel() {
return false; // 串行执行,确保准确性
}
@Override
public boolean isFailFast() {
return true; // 限流失败时快速返回
}
}4.3 RateLimitConfigListener 完整实现
@Component
@Slf4j
public class RateLimitConfigListener implements MessageListener {
@Resource
private RedisTemplate<String, Object> redisTemplate;
@Resource
private DistributedRateLimiter rateLimiter;
@PostConstruct
public void init() {
log.info("初始化限流配置监听器");
// 启动时加载所有配置
loadAllRateLimitConfigs();
// 订阅配置更新消息
// 注:订阅逻辑在 RedisConfig 中配置
}
@Override
public void onMessage(Message message, byte[] pattern) {
String body = new String(message.getBody());
log.info("收到限流配置更新消息: {}", body);
try {
if ("RELOAD_ALL".equals(body)) {
// 全量重载
log.info("执行全量重载限流配置");
loadAllRateLimitConfigs();
} else {
// 增量更新
RateLimitConfig config = JSON.parseObject(body, RateLimitConfig.class);
String key = buildConfigKey(config);
rateLimiter.updateConfig(key, config);
log.info("更新限流配置成功: key={}", key);
}
} catch (Exception e) {
log.error("处理限流配置更新消息失败", e);
}
}
/**
* 加载所有限流配置
*/
private void loadAllRateLimitConfigs() {
try {
// 从 Redis 加载所有配置
Set<String> keys = redisTemplate.keys("rate_limit_config:*");
if (keys == null || keys.isEmpty()) {
log.info("未找到限流配置");
return;
}
for (String key : keys) {
Map<Object, Object> configMap = redisTemplate.opsForHash().entries(key);
RateLimitConfig config = convertToConfig(configMap);
String configKey = buildConfigKey(config);
rateLimiter.updateConfig(configKey, config);
}
log.info("加载限流配置完成,共{}条", keys.size());
} catch (Exception e) {
log.error("加载限流配置失败", e);
}
}
private String buildConfigKey(RateLimitConfig config) {
return config.getLimitType() + ":" + config.getLimitTarget();
}
private RateLimitConfig convertToConfig(Map<Object, Object> map) {
return RateLimitConfig.builder()
.id(Long.valueOf(map.get("id").toString()))
.ruleName(map.get("ruleName").toString())
.limitType(map.get("limitType").toString())
.limitTarget(map.get("limitTarget").toString())
.limitCount(Integer.valueOf(map.get("limitCount").toString()))
.timeWindow(Integer.valueOf(map.get("timeWindow").toString()))
.enabled(Boolean.valueOf(map.get("enabled").toString()))
.strategy(map.get("strategy").toString())
.build();
}
}