背景:项目需要提供接口给第三方平台使用,并且要针对每个租户做请求限流(并发控制)。已知 Spring Cloud Gateway 整合了 Redis,使用令牌桶算法实现限流,参考:
Spring Cloud Gateway实战案例(限流、熔断回退、跨域、统一异常处理和重试机制)-腾讯云开发者社区-腾讯云 (tencent.com)
核心算法对象是 RedisRateLimiter(源码见下)。原本计划重写这个对象,把自己的限流规则写入其中,再交给 Spring Cloud Gateway 的逻辑使用;但是发现自己重写的对象一直无法注册到 bean 容器中,于是放弃了重写的方案,转而采用自己实现的过滤器。
package org.springframework.cloud.gateway.filter.ratelimit;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.validation.constraints.Min;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.jetbrains.annotations.NotNull;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import org.springframework.beans.BeansException;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.data.redis.core.ReactiveRedisTemplate;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.validation.Validator;
import org.springframework.validation.annotation.Validated;
/**
 * Token-bucket rate limiter whose bucket state is kept in Redis and updated by a
 * Redis script.
 *
 * See https://stripe.com/blog/rate-limiters and
 * https://gist.github.com/ptarjan/e38f45f2dfe601419ca3af937fff574d#file-1-check_request_rate_limiter-rb-L11-L34
 *
 * @author Spencer Gibb
 */
@ConfigurationProperties("spring.cloud.gateway.redis-rate-limiter")
public class RedisRateLimiter extends AbstractRateLimiter<RedisRateLimiter.Config> implements ApplicationContextAware {

    @Deprecated
    public static final String REPLENISH_RATE_KEY = "replenishRate";

    @Deprecated
    public static final String BURST_CAPACITY_KEY = "burstCapacity";

    public static final String CONFIGURATION_PROPERTY_NAME = "redis-rate-limiter";

    public static final String REDIS_SCRIPT_NAME = "redisRequestRateLimiterScript";

    public static final String REMAINING_HEADER = "X-RateLimit-Remaining";

    public static final String REPLENISH_RATE_HEADER = "X-RateLimit-Replenish-Rate";

    public static final String BURST_CAPACITY_HEADER = "X-RateLimit-Burst-Capacity";

    private final Log log = LogFactory.getLog(getClass());

    private ReactiveRedisTemplate<String, String> redisTemplate;

    private RedisScript<List<Long>> script;

    /** Set once the Redis template and script are available (constructor or application context). */
    private final AtomicBoolean initialized = new AtomicBoolean(false);

    /** Fallback configuration used when a route has no configuration of its own; may be null. */
    private Config defaultConfig;

    // configuration properties

    /** Whether or not to include headers containing rate limiter information, defaults to true. */
    private boolean includeHeaders = true;

    /** The name of the header that returns number of remaining requests during the current second. */
    private String remainingHeader = REMAINING_HEADER;

    /** The name of the header that returns the replenish rate configuration. */
    private String replenishRateHeader = REPLENISH_RATE_HEADER;

    /** The name of the header that returns the burst capacity configuration. */
    private String burstCapacityHeader = BURST_CAPACITY_HEADER;

    /**
     * Creates a fully initialized limiter from explicitly supplied collaborators.
     *
     * @param redisTemplate reactive template used to execute the script
     * @param script the rate limiter Redis script
     * @param validator validator handed to the parent class
     */
    public RedisRateLimiter(ReactiveRedisTemplate<String, String> redisTemplate,
            RedisScript<List<Long>> script, Validator validator) {
        super(Config.class, CONFIGURATION_PROPERTY_NAME, validator);
        this.redisTemplate = redisTemplate;
        this.script = script;
        initialized.compareAndSet(false, true);
    }

    /**
     * Creates a limiter with a default configuration. The Redis template and script are
     * resolved later in {@link #setApplicationContext(ApplicationContext)}; until then
     * {@link #isAllowed(String, String)} throws {@link IllegalStateException}.
     *
     * @param defaultReplenishRate replenish rate of the default configuration
     * @param defaultBurstCapacity burst capacity of the default configuration
     */
    public RedisRateLimiter(int defaultReplenishRate, int defaultBurstCapacity) {
        super(Config.class, CONFIGURATION_PROPERTY_NAME, null);
        this.defaultConfig = new Config()
                .setReplenishRate(defaultReplenishRate)
                .setBurstCapacity(defaultBurstCapacity);
    }

    public boolean isIncludeHeaders() {
        return includeHeaders;
    }

    public void setIncludeHeaders(boolean includeHeaders) {
        this.includeHeaders = includeHeaders;
    }

    public String getRemainingHeader() {
        return remainingHeader;
    }

    public void setRemainingHeader(String remainingHeader) {
        this.remainingHeader = remainingHeader;
    }

    public String getReplenishRateHeader() {
        return replenishRateHeader;
    }

    public void setReplenishRateHeader(String replenishRateHeader) {
        this.replenishRateHeader = replenishRateHeader;
    }

    public String getBurstCapacityHeader() {
        return burstCapacityHeader;
    }

    public void setBurstCapacityHeader(String burstCapacityHeader) {
        this.burstCapacityHeader = burstCapacityHeader;
    }

    /**
     * Resolves the Redis template, the script and (when one is registered) the validator
     * from the application context. Does nothing if this limiter is already initialized.
     */
    @Override
    @SuppressWarnings("unchecked")
    public void setApplicationContext(ApplicationContext context) throws BeansException {
        if (initialized.compareAndSet(false, true)) {
            this.redisTemplate = context.getBean("stringReactiveRedisTemplate", ReactiveRedisTemplate.class);
            this.script = context.getBean(REDIS_SCRIPT_NAME, RedisScript.class);
            if (context.getBeanNamesForType(Validator.class).length > 0) {
                this.setValidator(context.getBean(Validator.class));
            }
        }
    }

    /* for testing */ Config getDefaultConfig() {
        return defaultConfig;
    }

    /**
     * This uses a basic token bucket algorithm and relies on the fact that Redis scripts
     * execute atomically. No other operations can run between fetching the count and
     * writing the new count.
     *
     * @param routeId route whose configuration is looked up (falls back to the default configuration)
     * @param id identifier of the bucket, used to build the Redis keys
     * @return the decision plus rate limit headers; when Redis cannot be reached the
     * request is allowed and the remaining-token value is reported as -1
     * @throws IllegalStateException if the limiter has not been initialized
     * @throws IllegalArgumentException if neither a route nor a default configuration exists
     */
    @Override
    @SuppressWarnings("unchecked")
    public Mono<Response> isAllowed(String routeId, String id) {
        if (!this.initialized.get()) {
            throw new IllegalStateException("RedisRateLimiter is not initialized");
        }

        Config routeConfig = getConfig().getOrDefault(routeId, defaultConfig);
        if (routeConfig == null) {
            throw new IllegalArgumentException("No Configuration found for route " + routeId);
        }

        // How many requests per second do you want a user to be allowed to do?
        int replenishRate = routeConfig.getReplenishRate();
        // How much bursting do you want to allow?
        int burstCapacity = routeConfig.getBurstCapacity();

        try {
            List<String> keys = getKeys(id);
            // The arguments to the LUA script. time() returns unixtime in seconds.
            List<String> scriptArgs = Arrays.asList(replenishRate + "", burstCapacity + "",
                    Instant.now().getEpochSecond() + "", "1");
            // allowed, tokens_left = redis.eval(SCRIPT, keys, args)
            Flux<List<Long>> flux = this.redisTemplate.execute(this.script, keys, scriptArgs);
            return flux.onErrorResume(throwable -> {
                        // Fail open, but leave a trace so the failure is not invisible.
                        if (log.isDebugEnabled()) {
                            log.debug("Error calling rate limiter lua", throwable);
                        }
                        return Flux.just(Arrays.asList(1L, -1L));
                    })
                    // A fresh accumulator per subscription; a shared list would keep
                    // growing if the Mono were ever subscribed more than once.
                    .reduceWith(() -> new ArrayList<Long>(), (longs, l) -> {
                        longs.addAll(l);
                        return longs;
                    })
                    .map(results -> {
                        boolean allowed = results.get(0) == 1L;
                        Long tokensLeft = results.get(1);
                        Response response = new Response(allowed, getHeaders(routeConfig, tokensLeft));
                        if (log.isDebugEnabled()) {
                            log.debug("response: " + response);
                        }
                        return response;
                    });
        }
        catch (Exception e) {
            /*
             * We don't want a hard dependency on Redis to allow traffic. Make sure to set
             * an alert so you know if this is happening too much. Stripe's observed
             * failure rate is 0.01%.
             */
            log.error("Error determining if user allowed from redis", e);
        }
        return Mono.just(new Response(true, getHeaders(routeConfig, -1L)));
    }

    /**
     * Builds the rate limit headers for a response.
     *
     * @param config configuration whose replenish rate and burst capacity are reported
     * @param tokensLeft remaining tokens to report
     * @return the headers keyed by the configured header names, or an empty map when
     * {@code includeHeaders} is false
     */
    @NotNull
    public HashMap<String, String> getHeaders(Config config, Long tokensLeft) {
        HashMap<String, String> headers = new HashMap<>();
        if (this.includeHeaders) {
            headers.put(this.remainingHeader, tokensLeft.toString());
            headers.put(this.replenishRateHeader, String.valueOf(config.getReplenishRate()));
            headers.put(this.burstCapacityHeader, String.valueOf(config.getBurstCapacity()));
        }
        return headers;
    }

    /**
     * Builds the two Redis keys (tokens, timestamp) of the bucket identified by {@code id}.
     */
    static List<String> getKeys(String id) {
        // use `{}` around keys to use Redis Key hash tags
        // this allows for using redis cluster

        // Make a unique key per user.
        String prefix = "request_rate_limiter.{" + id;

        // You need two Redis keys for Token Bucket.
        String tokenKey = prefix + "}.tokens";
        String timestampKey = prefix + "}.timestamp";
        return Arrays.asList(tokenKey, timestampKey);
    }

    /** Per-route settings of the token bucket. */
    @Validated
    public static class Config {

        @Min(1)
        private int replenishRate;

        @Min(1)
        private int burstCapacity = 1;

        public int getReplenishRate() {
            return replenishRate;
        }

        public Config setReplenishRate(int replenishRate) {
            this.replenishRate = replenishRate;
            return this;
        }

        public int getBurstCapacity() {
            return burstCapacity;
        }

        public Config setBurstCapacity(int burstCapacity) {
            this.burstCapacity = burstCapacity;
            return this;
        }

        @Override
        public String toString() {
            return "Config{" +
                    "replenishRate=" + replenishRate +
                    ", burstCapacity=" + burstCapacity +
                    '}';
        }
    }
}
单元测试 验证限流效果
/**
 * Manual smoke test: runs the rate limiter script ten times against a standalone Redis
 * and logs each result, pausing one second whenever a call is rejected.
 *
 * @param args unused
 */
public static void main(String[] args) {
    RedisStandaloneConfiguration redisConfiguration = new RedisStandaloneConfiguration();
    redisConfiguration.setHostName("10.0.124.60");
    redisConfiguration.setPort(6379);
    LettuceConnectionFactory fac = new LettuceConnectionFactory(redisConfiguration);
    fac.afterPropertiesSet();
    try {
        RedisTemplate<String, String> template = new RedisTemplate<>();
        template.setConnectionFactory(fac);
        template.setDefaultSerializer(new StringRedisSerializer());
        template.afterPropertiesSet();

        DefaultRedisScript<List> redisScript = new DefaultRedisScript<>();
        redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("scripts/request_rate_limiter.lua")));
        redisScript.setResultType(List.class);

        String key = "key";
        String replenishRate = "3";
        String burstCapacity = "2";
        List<String> keys = getKeys(key);

        for (int count = 0; count < 10; count++) {
            // Script arguments: rate, capacity, current unix time in seconds, requested tokens.
            // allowed, tokens_left = redis.eval(SCRIPT, keys, args)
            List<Long> results = (List<Long>) template.execute(redisScript, keys,
                    replenishRate, burstCapacity, Instant.now().getEpochSecond() + "", "1");
            log.info("限流结果:{}", JSON.toJSONString(results));
            boolean allowed = results.get(0) == 1L;
            if (!allowed) {
                log.info("请求过于频繁,请稍后在请求");
                try {
                    Thread.sleep(1000);
                } catch (InterruptedException e) {
                    // Restore the interrupt flag and stop instead of swallowing it.
                    Thread.currentThread().interrupt();
                    break;
                }
            }
        }
    } finally {
        // Release the Lettuce client resources so the JVM can exit cleanly.
        fac.destroy();
    }
}
Lua 脚本(scripts/request_rate_limiter.lua)
-- Token bucket rate limiter.
-- KEYS[1]: key holding the tokens left in the bucket
-- KEYS[2]: key holding the time of the last refresh
-- ARGV[1]: replenish rate (tokens added per unit of ARGV[3]; callers pass epoch seconds)
-- ARGV[2]: bucket capacity
-- ARGV[3]: current time
-- ARGV[4]: number of tokens requested
-- Returns { allowed (1 or 0), tokens left after this call }.
local tokens_key = KEYS[1]
local timestamp_key = KEYS[2]
--redis.log(redis.LOG_WARNING, "tokens_key " .. tokens_key)

local rate = tonumber(ARGV[1])
local capacity = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local requested = tonumber(ARGV[4])

-- Time needed to refill an empty bucket; state is kept for twice that long.
local fill_time = capacity/rate
-- SETEX rejects a non-positive expire time, which floor() yields whenever
-- rate > 2 * capacity, so keep the state for at least one second.
local ttl = math.max(1, math.floor(fill_time*2))

-- A missing key means the bucket has expired and is therefore full again.
local last_tokens = tonumber(redis.call("get", tokens_key))
if last_tokens == nil then
  last_tokens = capacity
end
--redis.log(redis.LOG_WARNING, "last_tokens " .. last_tokens)

local last_refreshed = tonumber(redis.call("get", timestamp_key))
if last_refreshed == nil then
  last_refreshed = 0
end
--redis.log(redis.LOG_WARNING, "last_refreshed " .. last_refreshed)

-- Refill according to the elapsed time, never above capacity.
local delta = math.max(0, now-last_refreshed)
local filled_tokens = math.min(capacity, last_tokens+(delta*rate))
local allowed = filled_tokens >= requested
local new_tokens = filled_tokens
local allowed_num = 0
if allowed then
  new_tokens = filled_tokens - requested
  allowed_num = 1
end

redis.call("setex", tokens_key, ttl, new_tokens)
redis.call("setex", timestamp_key, ttl, now)

return { allowed_num, new_tokens }
RateLimiterFilter 过滤器
package com.jdh.opengateway.filter;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.jdh.opengateway.model.cif.RateLimitTenantRule;
import com.jdh.opengateway.model.cif.RateLimitUrlRule;
import com.jdh.opengateway.utils.HttpConstant;
import com.jdh.opengateway.utils.Sign;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.cloud.gateway.support.BodyInserterContext;
import org.springframework.cloud.gateway.support.CachedBodyOutputMessage;
import org.springframework.cloud.gateway.support.DefaultServerRequest;
import org.springframework.core.Ordered;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.data.redis.connection.RedisStandaloneConfiguration;
import org.springframework.data.redis.connection.lettuce.LettuceConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.serializer.StringRedisSerializer;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.scripting.support.ResourceScriptSource;
import org.springframework.stereotype.Component;
import org.springframework.util.Assert;
import org.springframework.web.reactive.function.BodyInserter;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.server.ServerRequest;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import javax.annotation.PostConstruct;
import java.net.URI;
import java.time.Instant;
import java.util.*;
import java.util.concurrent.atomic.AtomicReference;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_ORIGINAL_REQUEST_URL_ATTR;
@Component
@SuppressWarnings("All")
public class RateLimiterFilter implements GlobalFilter, Ordered {
private static final Logger log = LoggerFactory.getLogger(RateLimiterFilter.class);
static DefaultRedisScript<List> redisScript;
@Autowired
RedisTemplate redisTemplate;
@PostConstruct
public void redisRequestRateLimiterScript() {
redisScript = new DefaultRedisScript<>();
redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("scripts/request_rate_limiter.lua")));
redisScript.setResultType(List.class);
}
@Override
public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
ServerHttpRequest request = exchange.getRequest();
ServerHttpResponse response = exchange.getResponse();
String contentType = request.getHeaders().getFirst(HttpConstant.HTTP_HEADER_CONTENT_TYPE);
String method = request.getMethodValue();
String url = getUrl(exchange);
//判断是否为POST请求
if (null != contentType && HttpMethod.POST.name().equalsIgnoreCase(method)) {
ServerRequest serverRequest = new DefaultServerRequest(exchange);
// 读取请求体
Mono<String> modifiedBody = serverRequest.bodyToMono(String.class)
.flatMap(body -> {
try {
reqLimiter(body, url, request);
} catch (Exception e) {
return Mono.error(e);
}
return Mono.just(body);
});
return returnMono(exchange, chain, modifiedBody);
}
return chain.filter(exchange);
}
static List<String> getKeys(String id) {
// use `{}` around keys to use Redis Key hash tags
// this allows for using redis cluster
// Make a unique key per user.
String prefix = "request_rate_limiter.{" + id;
// You need two Redis keys for Token Bucket.
String tokenKey = prefix + "}.tokens";
String timestampKey = prefix + "}.timestamp";
return Arrays.asList(tokenKey, timestampKey);
}
public void reqLimiter(String body, String url, ServerHttpRequest request) {
Map<String, String> headerMap = request.getHeaders().toSingleValueMap();
JSONObject jsonObject = JSON.parseObject(body);
String key = headerMap.get(Sign.CA_PROXY_SIGN_SECRET_KEY);
Object limitKey = redisTemplate.opsForHash().get("ratelimitkey", key);
RateLimitTenantRule rateLimitTenantRule = JSONObject.parseObject((String) limitKey, RateLimitTenantRule.class);
HashMap<String, RateLimitUrlRule> urlRuleMap = rateLimitTenantRule.getUrlRuleMap();
RateLimitUrlRule rateLimitUrlRule = urlRuleMap.get(url);
List<String> keys = getKeys(key);
String replenishRate = "10";
String burstCapacity = "10";
if (rateLimitUrlRule != null) {
key = key + url;//Redis key精确到URL
keys = getKeys(key);
replenishRate = rateLimitUrlRule.getReplenishRate();
burstCapacity = rateLimitUrlRule.getBurstCapacity();
}
// The arguments to the LUA script. time() returns unixtime in seconds.
List<String> scriptArgs = Arrays.asList(replenishRate + "", burstCapacity + "", Instant.now().getEpochSecond() + "", "1");
// allowed, tokens_left = redis.eval(SCRIPT, keys, args)
List<Long> results = (List<Long>) redisTemplate.execute(redisScript, keys, scriptArgs);
log.info("限流结果:{}", JSON.toJSONString(results));
// .log("redisratelimiter", Level.FINER);
boolean allowed = results.get(0) == 1L;
Long tokensLeft = results.get(1);
Assert.isTrue(allowed, "请求过于频繁,请稍后在请求");
}
public static void main(String[] args) {
RedisTemplate<String, String> template = new RedisTemplate<>();
RedisStandaloneConfiguration redisClusterConfiguration = new RedisStandaloneConfiguration();
redisClusterConfiguration.setHostName("10.0.124.60");
redisClusterConfiguration.setPort(6379);
LettuceConnectionFactory fac = new LettuceConnectionFactory(redisClusterConfiguration);
fac.afterPropertiesSet();
template.setConnectionFactory(fac);
template.setDefaultSerializer(new StringRedisSerializer());
template.afterPropertiesSet();
DefaultRedisScript redisScript = new DefaultRedisScript<>();
redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("scripts/request_rate_limiter.lua")));
redisScript.setResultType(List.class);
int count = 0;
while (count < 10) {
String key = "key";//headerMap.get(Sign.CA_PROXY_SIGN_SECRET_KEY);
List<String> keys = getKeys(key);
// Object limitKey = redisTemplate.opsForHash().get("limit_key", key);
// JSONObject limitKeyJson = JSON.parseObject((String) limitKey);
String replenishRate = "3";//limitKeyJson.getString("replenishRate");
String burstCapacity = "2";//limitKeyJson.getString("burstCapacity");
// The arguments to the LUA script. time() returns unixtime in seconds.
List<String> scriptArgs = Arrays.asList();
// allowed, tokens_left = redis.eval(SCRIPT, keys, args)
List<Long> results = (List<Long>) template.execute(redisScript, keys, replenishRate + "", burstCapacity + "", Instant.now().getEpochSecond() + "", "1");
log.info("限流结果:{}", JSON.toJSONString(results));
// .log("redisratelimiter", Level.FINER);
boolean allowed = results.get(0) == 1L;
Long tokensLeft = results.get(1);
if (!allowed) {
log.info("请求过于频繁,请稍后在请求");
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
e.printStackTrace();
}
}
;
count++;
}
}
/**
* 获取请求URL
*
* @param exchange
* @return
*/
private String getUrl(ServerWebExchange exchange) {
LinkedHashSet<URI> uris = exchange.getAttribute(GATEWAY_ORIGINAL_REQUEST_URL_ATTR);
AtomicReference<String> path = new AtomicReference<>("");
uris.forEach(uri -> {
path.set(uri.getPath());
});
return path.get();
}
/**
* 返回结果
*
* @param openReq
* @param exchange
* @param chain
* @param modifiedBody
* @param token
* @return
*/
private Mono<Void> returnMono(ServerWebExchange exchange, GatewayFilterChain chain, Mono<String> modifiedBody) {
BodyInserter bodyInserter = BodyInserters.fromPublisher(modifiedBody, String.class);
CachedBodyOutputMessage outputMessage = new CachedBodyOutputMessage(exchange, exchange.getRequest().getHeaders());
return bodyInserter.insert(outputMessage, new BodyInserterContext())
.then(Mono.defer(() -> {
ServerHttpRequestDecorator decorator = new ServerHttpRequestDecorator(exchange.getRequest()) {
@Override
public Flux<DataBuffer> getBody() {
return outputMessage.getBody();
}
@Override
public HttpHeaders getHeaders() {
HttpHeaders httpHeaders = new HttpHeaders();
httpHeaders.putAll(super.getHeaders());
return httpHeaders;
}
};
return chain.filter(exchange.mutate().request(decorator).build());
}));
}
@Override
public int getOrder() {
return 3;
}
}
package com.jdh.opengateway.model.cif;
import lombok.Data;
import java.io.Serializable;
import java.util.HashMap;
/**
 * Container for the rate limit rules of several tenants.
 */
@Data
public class RateLimitRule implements Serializable {
/**
* Rate limit rules of multiple tenants.
* NOTE(review): presumably keyed by tenant id — confirm against the code that fills it.
*/
private HashMap<String,RateLimitTenantRule> rateLimitTenantRuleMap;
}
package com.jdh.opengateway.model.cif;
import lombok.Data;
import java.io.Serializable;
import java.util.HashMap;
/**
 * Rate limit rules of a single tenant.
 */
@Data
public class RateLimitTenantRule implements Serializable {
/**
* Rate limit configured for each interface.
* NOTE(review): presumably keyed by request URL path — confirm against the caller.
*/
private HashMap<String,RateLimitUrlRule> urlRuleMap;
}
package com.jdh.opengateway.model.cif;
import lombok.Data;
import java.io.Serializable;
/**
 * Token bucket settings for one URL; both values are held as strings.
 */
@Data
public class RateLimitUrlRule implements Serializable {
/**
* Rate at which tokens are added, e.g. 10 per second.
*/
private String replenishRate;
/**
* Number of tokens the bucket holds, e.g. 30.
*/
private String burstCapacity;
}
每个租户存储的限流规则(存放在 Redis Hash 中:key 为 ratelimitkey,field 为租户 id,value 为该租户各接口限流规则的 JSON)
{
"ratelimitkey": {
"租户1id": {
"请求url": {
"replenishRate": "令牌投递速度 几个每秒",
"burstCapacity": "令牌个数"
}
},
"租户2id": {
"请求url": {
"replenishRate": "令牌投递速度 几个每秒",
"burstCapacity": "令牌个数"
}
}
}
}