分布式限流器 - 算法实现与代码
2025/11/13大约 10 分钟
分布式限流器 - 算法实现与代码
一、限流算法详解
3.1 令牌桶算法 (Token Bucket)
算法原理
令牌桶算法是一种常用的限流算法,核心思想是:
- 令牌生成:系统以恒定速率向桶中放入令牌
- 令牌消费:每个请求需要消耗一个令牌
- 令牌不足:桶中没有令牌时,请求被拒绝
- 桶有容量:令牌数量有上限,多余的令牌会被丢弃
时间轴: t0 t1 t2 t3 t4
↓ ↓ ↓ ↓ ↓
令牌桶: [●●●] → [●●●●] → [●●●] → [●●●●] → [●●●]
↑生成 ↑请求 ↑生成 ↑请求 ↑生成
令牌 消耗 令牌 消耗 令牌算法特点
优点:
- 允许突发流量:桶中积累的令牌可以应对短时间的流量突发
- 平滑限流:长期来看,流量被平滑到恒定速率
- 实现简单:逻辑清晰,易于理解和实现
缺点:
- 突发流量可能过大:如果桶容量设置过大,可能导致瞬时流量过高
适用场景:
- 全局限流:保护网关整体,允许一定突发
- 服务级限流:保护后端服务,提高资源利用率
- IP级限流:防止恶意攻击,允许正常用户的突发请求
Redis + Lua 实现
-- 令牌桶算法 Lua 脚本
-- KEYS[1]: 限流 key
-- ARGV[1]: 限流阈值(每秒令牌数)
local key = KEYS[1]
local limit = tonumber(ARGV[1])
-- 获取当前计数
local current = tonumber(redis.call('get', key) or '0')
if current < limit then
-- 令牌充足,计数+1
redis.call('incr', key)
-- 首次计数时设置过期时间
if current == 0 then
redis.call('expire', key, 1)
end
return 1 -- 允许通过
else
return 0 -- 拒绝请求
end实现细节:
- 计数器重置:使用 Redis 的
expire命令,1秒后自动重置计数器 - 原子性保证:Lua 脚本在 Redis 中原子执行,避免并发问题
- 简化实现:不维护真实的令牌桶,而是用计数器模拟
Java 调用代码:
public boolean tryAcquireWithTokenBucket(String key, RateLimitConfig config) {
String script =
"local key = KEYS[1]\n" +
"local limit = tonumber(ARGV[1])\n" +
"local current = tonumber(redis.call('get', key) or '0')\n" +
"if current < limit then\n" +
" redis.call('incr', key)\n" +
" if current == 0 then\n" +
" redis.call('expire', key, 1)\n" +
" end\n" +
" return 1\n" +
"else\n" +
" return 0\n" +
"end";
Long result = redisTemplate.execute(
new DefaultRedisScript<>(script, Long.class),
Collections.singletonList(key),
config.getLimitCount().toString()
);
return result != null && result == 1;
}3.2 滑动窗口算法 (Sliding Window)
算法原理
滑动窗口算法将时间划分为多个窗口,统计窗口内的请求数:
- 时间窗口:定义一个时间窗口(如1秒)
- 请求记录:记录每个请求的时间戳
- 窗口滑动:随着时间推移,窗口向前滑动
- 计数判断:统计当前窗口内的请求数,超过阈值则拒绝
时间轴: |----1s----|----1s----|----1s----|
请求: ●●● ●●●● ●●
↑ ↑ ↑
窗口1 窗口2 窗口3
(3个) (4个) (2个)滑动窗口示例:
当前时间:1000ms
窗口大小:1000ms
窗口范围:[0ms, 1000ms]
请求时间戳:100ms, 200ms, 300ms, 1100ms
窗口内请求:100ms, 200ms, 300ms (3个)
窗口外请求:1100ms (被移除)算法特点
优点:
- 精确控制:严格限制时间窗口内的请求数
- 防止突发:不允许超过阈值的突发流量
- 实时性好:窗口实时滑动,响应及时
缺点:
- 存储开销:需要记录每个请求的时间戳
- 计算复杂:需要清理过期数据,计算窗口内请求数
适用场景:
- 接口级限流:保护核心接口(如登录、支付),需要精确控制
- 防止暴力破解:严格限制登录尝试次数
- 防止爬虫:精确控制单个IP的请求频率
Redis + Lua 实现
-- 滑动窗口算法 Lua 脚本
-- KEYS[1]: 限流 key
-- ARGV[1]: 限流阈值
-- ARGV[2]: 时间窗口(秒)
-- ARGV[3]: 当前时间戳(毫秒)
local key = KEYS[1]
local limit = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local current = tonumber(ARGV[3])
-- 计算窗口起始时间
local expire_time = current - window * 1000
-- 删除窗口外的过期数据
redis.call('zremrangebyscore', key, 0, expire_time)
-- 统计当前窗口内的请求数
local count = redis.call('zcard', key)
if count < limit then
-- 未达到限流阈值,添加当前请求
redis.call('zadd', key, current, current)
-- 设置过期时间(窗口大小 + 1秒)
redis.call('expire', key, window + 1)
return 1 -- 允许通过
else
return 0 -- 拒绝请求
end实现细节:
- 数据结构:使用 Redis 的 Sorted Set,score 为时间戳
- 过期清理:使用
zremrangebyscore删除窗口外的数据 - 计数统计:使用
zcard统计窗口内的请求数 - 自动过期:设置 key 的过期时间,避免内存泄漏
Java 调用代码:
public boolean tryAcquireWithSlidingWindow(String key, RateLimitConfig config) {
String script =
"local key = KEYS[1]\n" +
"local limit = tonumber(ARGV[1])\n" +
"local window = tonumber(ARGV[2])\n" +
"local current = tonumber(ARGV[3])\n" +
"local expire_time = current - window * 1000\n" +
"redis.call('zremrangebyscore', key, 0, expire_time)\n" +
"local count = redis.call('zcard', key)\n" +
"if count < limit then\n" +
" redis.call('zadd', key, current, current)\n" +
" redis.call('expire', key, window + 1)\n" +
" return 1\n" +
"else\n" +
" return 0\n" +
"end";
Long result = redisTemplate.execute(
new DefaultRedisScript<>(script, Long.class),
Collections.singletonList(key),
config.getLimitCount().toString(),
config.getTimeWindow().toString(),
String.valueOf(System.currentTimeMillis())
);
return result != null && result == 1;
}3.3 算法选择策略
| 场景 | 推荐算法 | 理由 |
|---|---|---|
| 全局限流 | 令牌桶 | 允许突发流量,提高资源利用率,避免过度限流 |
| 服务级限流 | 令牌桶 | 保护后端服务,允许短时突发,提升用户体验 |
| 接口级限流 | 滑动窗口 | 精确控制核心接口,防止接口被打垮 |
| IP级限流 | 令牌桶 | 防止恶意攻击,允许正常用户的突发请求 |
| 登录接口 | 滑动窗口 | 防止暴力破解,严格限制尝试次数 |
| 支付接口 | 滑动窗口 | 保护资金安全,精确控制请求频率 |
选择原则:
- 需要允许突发流量 → 令牌桶
- 需要精确控制 → 滑动窗口
- 性能要求高 → 令牌桶(实现更简单,性能更好)
- 安全要求高 → 滑动窗口(控制更严格)
二、核心代码实现
4.1 DistributedRateLimiter 完整实现
@Component
@Slf4j
public class DistributedRateLimiter {
@Resource
private RedisTemplate<String, Object> redisTemplate;
// 本地令牌计数器(LOCAL_DISTRIBUTED 模式)
private final Map<String, AtomicInteger> localTokenCounters = new ConcurrentHashMap<>();
// 配置缓存
private final Map<String, RateLimitConfig> configCache = new ConcurrentHashMap<>();
/**
* 尝试获取令牌
* @param key 限流 key
* @param config 限流配置
* @return true-允许通过,false-拒绝请求
*/
public boolean tryAcquire(String key, RateLimitConfig config) {
if (config == null || !config.getEnabled()) {
return true;
}
try {
String mode = config.getMode() != null ? config.getMode() : "DISTRIBUTED";
if ("LOCAL_DISTRIBUTED".equals(mode)) {
// 本地+分布式混合模式
return tryAcquireLocalDistributed(key, config);
} else {
// 默认分布式模式
return tryAcquireDistributed(key, config);
}
} catch (Exception e) {
log.error("限流异常,降级为放行: key={}", key, e);
return true; // 异常时放行,保证服务可用性
}
}
/**
* 分布式限流(默认模式)
* 直接使用 Redis 侧的滑动窗口或令牌桶算法进行限流
*/
private boolean tryAcquireDistributed(String key, RateLimitConfig config) {
String redisKey = "rate_limit:" + key;
if ("TOKEN_BUCKET".equals(config.getStrategy())) {
return tryAcquireTokenBucket(redisKey, config);
} else {
return tryAcquireSlidingWindow(redisKey, config);
}
}
/**
* 本地+分布式混合模式
* 使用令牌桶算法从 Redis 批量获取令牌,然后在本地进行高性能限流
*/
private boolean tryAcquireLocalDistributed(String key, RateLimitConfig config) {
// 获取或创建本地令牌计数器
AtomicInteger tokenCounter = localTokenCounters.computeIfAbsent(key, k -> new AtomicInteger(0));
// 尝试消费本地令牌
int currentTokens = tokenCounter.get();
if (currentTokens > 0) {
// 本地有令牌,直接消费
if (tokenCounter.compareAndSet(currentTokens, currentTokens - 1)) {
return true;
}
// CAS 失败,重试
return tryAcquireLocalDistributed(key, config);
}
// 本地令牌不足,从 Redis 批量获取
Integer batchSize = config.getLocalBatchSize() != null ? config.getLocalBatchSize() : 100;
int acquiredTokens = batchGetTokensFromRedis(key, config, batchSize);
if (acquiredTokens > 0) {
// 成功获取令牌,设置本地计数器(减1是因为当前请求消费一个)
tokenCounter.set(acquiredTokens - 1);
return true;
}
// 无法获取令牌,限流
return false;
}
/**
* 从 Redis 批量获取令牌
*/
private int batchGetTokensFromRedis(String key, RateLimitConfig config, int batchSize) {
try {
String redisKey = "rate_limit:" + key;
DefaultRedisScript<Long> script = new DefaultRedisScript<>();
script.setScriptText(BATCH_GET_TOKENS_SCRIPT);
script.setResultType(Long.class);
Long result = redisTemplate.execute(
script,
Collections.singletonList(redisKey),
batchSize,
config.getLimitCount()
);
return result != null ? result.intValue() : 0;
} catch (Exception e) {
log.error("从 Redis 批量获取令牌失败: key={}", key, e);
return 0;
}
}
/**
* 令牌桶算法
*/
private boolean tryAcquireTokenBucket(String redisKey, RateLimitConfig config) {
try {
DefaultRedisScript<Long> script = new DefaultRedisScript<>();
script.setScriptText(TOKEN_BUCKET_SCRIPT);
script.setResultType(Long.class);
Long result = redisTemplate.execute(
script,
Collections.singletonList(redisKey),
config.getLimitCount()
);
return result == 1L;
} catch (Exception e) {
log.error("令牌桶限流执行失败: {}", redisKey, e);
return true; // 异常时放行
}
}
/**
* 滑动窗口算法
*/
private boolean tryAcquireSlidingWindow(String redisKey, RateLimitConfig config) {
try {
DefaultRedisScript<Long> script = new DefaultRedisScript<>();
script.setScriptText(SLIDING_WINDOW_SCRIPT);
script.setResultType(Long.class);
long currentTime = System.currentTimeMillis();
Long result = redisTemplate.execute(
script,
Collections.singletonList(redisKey),
config.getLimitCount(),
config.getTimeWindow(),
currentTime
);
return result == 1L;
} catch (Exception e) {
log.error("滑动窗口限流执行失败: {}", redisKey, e);
return true; // 异常时放行
}
}
/**
* 更新配置
*/
public void updateConfig(String key, RateLimitConfig config) {
log.info("更新限流配置: key={}, mode={}, config={}", key, config.getMode(), config);
configCache.put(key, config);
localTokenCounters.remove(key); // 清除旧的令牌计数器
}
/**
* 获取配置
*/
public RateLimitConfig getConfig(String key) {
return configCache.get(key);
}
/**
* 移除配置
*/
public void removeConfig(String key) {
log.info("移除限流配置: key={}", key);
configCache.remove(key);
localTokenCounters.remove(key);
}
/**
* 清空所有限流配置
*/
public void clearAllConfigs() {
log.info("清空所有限流配置");
configCache.clear();
localTokenCounters.clear();
}
}4.2 批量获取令牌的 Lua 脚本
-- 批量获取令牌(用于 LOCAL_DISTRIBUTED 模式)
-- KEYS[1]: 限流 key
-- ARGV[1]: 批量大小
-- ARGV[2]: 限流阈值
local key = KEYS[1]
local batch_size = tonumber(ARGV[1])
local limit = tonumber(ARGV[2])
local current = tonumber(redis.call('get', key) or '0')
-- 计算可用令牌数(不超过批量大小,也不超过剩余容量)
local available = math.min(batch_size, limit - current)
if available > 0 then
-- 增加计数
redis.call('incrby', key, available)
-- 首次计数时设置过期时间
if current == 0 then
redis.call('expire', key, 1)
end
return available
else
return 0
end4.3 RateLimitHandler 完整实现
@Slf4j
@Component
@ChannelHandler.Sharable
public class RateLimitHandler extends BaseHandler<FullHttpRequest> {
private static final AttributeKey<HttpStatement> HTTP_STATEMENT_KEY = AttributeKey.valueOf("HttpStatement");
@Resource
private DistributedRateLimiter rateLimiter;
@Override
protected void handle(ChannelHandlerContext ctx, Channel channel, FullHttpRequest request) {
HttpStatement httpStatement = channel.attr(HTTP_STATEMENT_KEY).get();
if (httpStatement == null) {
sendError(channel, "系统处理异常");
return;
}
try {
// 1. 全局限流检查
if (!checkGlobalRateLimit()) {
sendError(channel, "系统繁忙,请稍后重试");
return;
}
// 2. 服务级限流检查
String serviceId = httpStatement.getServiceId();
if (!checkServiceRateLimit(serviceId)) {
sendError(channel, "服务访问频繁,请稍后重试");
return;
}
// 3. 接口级限流检查
String url = RequestParameterUtil.getUrl(request);
if (!checkInterfaceRateLimit(serviceId, url)) {
sendError(channel, "接口访问频繁,请稍后重试");
return;
}
// 4. IP级限流检查
String clientIp = getClientIp(request);
if (!checkIpRateLimit(clientIp)) {
sendError(channel, "访问过于频繁,请稍后重试");
return;
}
// 所有限流检查通过,继续处理链
ctx.fireChannelRead(request);
} catch (Exception e) {
log.error("限流处理异常", e);
// 异常时放行,避免影响正常业务
ctx.fireChannelRead(request);
}
}
private boolean checkGlobalRateLimit() {
String key = "GLOBAL";
RateLimitConfig config = rateLimiter.getConfig(key);
if (config == null || !config.getEnabled()) {
return true;
}
return rateLimiter.tryAcquire(key, config);
}
private boolean checkServiceRateLimit(String serviceId) {
String key = "SERVICE:" + serviceId;
RateLimitConfig config = rateLimiter.getConfig(key);
if (config == null || !config.getEnabled()) {
return true;
}
return rateLimiter.tryAcquire(key, config);
}
private boolean checkInterfaceRateLimit(String serviceId, String url) {
String key = "INTERFACE:" + serviceId + ":" + url;
RateLimitConfig config = rateLimiter.getConfig(key);
if (config == null || !config.getEnabled()) {
return true;
}
return rateLimiter.tryAcquire(key, config);
}
private boolean checkIpRateLimit(String clientIp) {
String key = "IP:" + clientIp;
RateLimitConfig config = rateLimiter.getConfig(key);
if (config == null || !config.getEnabled()) {
return true;
}
return rateLimiter.tryAcquire(key, config);
}
private String getClientIp(FullHttpRequest request) {
// 尝试从X-Forwarded-For获取真实IP
String xff = request.headers().get("X-Forwarded-For");
if (xff != null && !xff.isEmpty()) {
return xff.split(",")[0].trim();
}
// 尝试从X-Real-IP获取
String realIp = request.headers().get("X-Real-IP");
if (realIp != null && !realIp.isEmpty()) {
return realIp;
}
// 默认返回unknown
return "unknown";
}
private void sendError(Channel channel, String message) {
channel.writeAndFlush(RequestResultUtil.parse(Result.error(message)));
}
}4.3 RateLimitConfigListener 完整实现
@Component
@Slf4j
public class RateLimitConfigListener implements MessageListener {
@Resource
private RedisTemplate<String, Object> redisTemplate;
@Resource
private DistributedRateLimiter rateLimiter;
@PostConstruct
public void init() {
log.info("初始化限流配置监听器");
// 启动时加载所有配置
loadAllRateLimitConfigs();
// 订阅配置更新消息
// 注:订阅逻辑在 RedisConfig 中配置
}
@Override
public void onMessage(Message message, byte[] pattern) {
String body = new String(message.getBody());
log.info("收到限流配置更新消息: {}", body);
try {
if ("RELOAD_ALL".equals(body)) {
// 全量重载
log.info("执行全量重载限流配置");
loadAllRateLimitConfigs();
} else {
// 增量更新
RateLimitConfig config = JSON.parseObject(body, RateLimitConfig.class);
String key = buildConfigKey(config);
rateLimiter.updateConfig(key, config);
log.info("更新限流配置成功: key={}", key);
}
} catch (Exception e) {
log.error("处理限流配置更新消息失败", e);
}
}
/**
* 加载所有限流配置
*/
private void loadAllRateLimitConfigs() {
try {
// 从 Redis 加载所有配置
Set<String> keys = redisTemplate.keys("rate_limit_config:*");
if (keys == null || keys.isEmpty()) {
log.info("未找到限流配置");
return;
}
for (String key : keys) {
Map<Object, Object> configMap = redisTemplate.opsForHash().entries(key);
RateLimitConfig config = convertToConfig(configMap);
String configKey = buildConfigKey(config);
rateLimiter.updateConfig(configKey, config);
}
log.info("加载限流配置完成,共{}条", keys.size());
} catch (Exception e) {
log.error("加载限流配置失败", e);
}
}
private String buildConfigKey(RateLimitConfig config) {
return config.getLimitType() + ":" + config.getLimitTarget();
}
private RateLimitConfig convertToConfig(Map<Object, Object> map) {
return RateLimitConfig.builder()
.id(Long.valueOf(map.get("id").toString()))
.ruleName(map.get("ruleName").toString())
.limitType(map.get("limitType").toString())
.limitTarget(map.get("limitTarget").toString())
.limitCount(Integer.valueOf(map.get("limitCount").toString()))
.timeWindow(Integer.valueOf(map.get("timeWindow").toString()))
.enabled(Boolean.valueOf(map.get("enabled").toString()))
.strategy(map.get("strategy").toString())
.build();
}
}