写作缘由
在和某学长炫耀自己会用Redis+Lua实现滑动窗口限流时,他说现在大家都用RateLimiter,所以我就想搞个Demo。但是度娘了一下,感觉搜索到的博客有几个我个人认为不太完善的地方:比如只贴了部分代码、没贴依赖,尤其是用AOP实现的时候,具体依赖哪个包其实还是有讲究的;另外一个问题是大多都基于AOP实现,而拦截器也是一个不错的方式,所以此处用拦截器HandlerInterceptorAdapter实现。
源码下载
https://github.com/cbeann/Demooo/tree/master/springboot-ratelimiter
部分代码
pom
<!-- https://mvnrepository.com/artifact/com.google.guava/guava -->
<!-- Guava supplies com.google.common.util.concurrent.RateLimiter, which the interceptor below uses.
     The Spring MVC classes used below (HandlerInterceptorAdapter, WebMvcConfigurerAdapter, ...)
     are not covered by this snippet; presumably spring-boot-starter-web is already declared in
     the project's pom - confirm before copying. -->
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>28.1-jre</version>
</dependency>
自定义注解ExtRateLimiter
package com.example.annotation;
import java.lang.annotation.*;
/**
 * Declares token-bucket rate-limit settings for a single method.
 *
 * <p>Retained at runtime so it can be read reflectively; in this project it is read by
 * {@code RateLimiterInceptor}.
 *
 * @author chaird
 * @create 2021-03-20 17:57
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD})
public @interface ExtRateLimiter {

  /** Fixed rate, in tokens per second, at which tokens are added to the bucket. */
  double permitsPerSecond();

  /**
   * Maximum wait, in milliseconds, for a token. If no token is obtained within this window the
   * request is degraded (rejected) instead of being served.
   */
  long timeout();
}
拦截器
package com.example.Interceptor;
import com.example.annotation.ExtRateLimiter;
import com.google.common.util.concurrent.RateLimiter;
import org.springframework.stereotype.Component;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.ModelAndView;
import org.springframework.web.servlet.handler.HandlerInterceptorAdapter;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.PrintWriter;
import java.lang.reflect.Method;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
/**
 * Spring MVC interceptor that enforces {@link ExtRateLimiter} limits on annotated handler methods.
 *
 * <p>One Guava {@link RateLimiter} is kept per handler method. It is created lazily on the first
 * request, using the annotation's {@code permitsPerSecond}. A request may wait up to the
 * annotation's {@code timeout} (milliseconds) for a permit. If none is granted, the interceptor
 * writes a short plain-text rejection body with HTTP 429 and stops the handler chain.
 *
 * @author CBeann
 * @create 2020-07-04 18:06
 */
@Component
public class RateLimiterInceptor extends HandlerInterceptorAdapter {

  /** HTTP 429 Too Many Requests; javax.servlet's HttpServletResponse has no constant for it. */
  private static final int SC_TOO_MANY_REQUESTS = 429;

  /** Body written to the client when a request is rejected. */
  private static final String REJECT_MESSAGE = "限流";

  /**
   * Per-method limiters, keyed by {@link Method#toString()}. That string includes the declaring
   * class, name and parameter types, so overloads and same-named methods in different classes
   * never share a bucket.
   */
  private final Map<String, RateLimiter> rateHashMap = new ConcurrentHashMap<>();

  /**
   * Lets the request through unless its handler is annotated with {@link ExtRateLimiter} and no
   * permit could be acquired within the annotation's timeout.
   *
   * @return {@code true} to continue processing; {@code false} after a rejection has been written
   */
  @Override
  public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler)
      throws Exception {
    // Only controller methods can carry the annotation; everything else passes through.
    if (!(handler instanceof HandlerMethod)) {
      return true;
    }
    Method method = ((HandlerMethod) handler).getMethod();
    ExtRateLimiter extRateLimiter = method.getAnnotation(ExtRateLimiter.class);
    if (extRateLimiter == null) {
      return true;
    }

    double permitsPerSecond = extRateLimiter.permitsPerSecond();
    // computeIfAbsent is atomic on ConcurrentHashMap. With a separate get/put, concurrent first
    // requests could each create (and overwrite) their own limiter and hand out extra permits.
    RateLimiter rateLimiter =
        rateHashMap.computeIfAbsent(method.toString(), key -> RateLimiter.create(permitsPerSecond));

    // May block the request thread for up to `timeout` ms while waiting for a permit.
    if (rateLimiter.tryAcquire(extRateLimiter.timeout(), TimeUnit.MILLISECONDS)) {
      return true;
    }

    // Rejected: report 429 with a body that matches its declared content type.
    // The charset must be set before getWriter() is called.
    response.setStatus(SC_TOO_MANY_REQUESTS);
    response.setContentType("text/plain; charset=utf-8");
    PrintWriter writer = response.getWriter();
    writer.print(REJECT_MESSAGE);
    writer.flush();
    return false;
  }
}
拦截器配置
package com.example.config;
import com.example.Interceptor.RateLimiterInceptor;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurerAdapter;
/**
 * Spring MVC configuration that registers {@link RateLimiterInceptor}.
 *
 * <p>NOTE(review): WebMvcConfigurerAdapter is deprecated as of Spring 5.0 in favour of
 * implementing WebMvcConfigurer directly. Switch once the project's Spring version is confirmed to
 * be 5.x or later.
 *
 * @author chaird
 * @create 2020-09-23 16:13
 */
@Configuration
public class MVCConfig extends WebMvcConfigurerAdapter {

  @Autowired private RateLimiterInceptor rateLimiterInceptor;

  /**
   * Applies the rate-limit interceptor to every request path.
   *
   * <p>The pattern must be {@code "/**"}. {@code "/*"} matches only single-segment paths such as
   * {@code /getStock}, so an annotated handler mapped to a nested path (e.g. {@code /api/stock})
   * would silently escape rate limiting. Matching everything is safe because the interceptor lets
   * handlers without {@code @ExtRateLimiter} pass straight through.
   */
  @Override
  public void addInterceptors(InterceptorRegistry registry) {
    registry.addInterceptor(rateLimiterInceptor).addPathPatterns("/**");
  }
}
controller
package com.example.controller;
import com.example.annotation.ExtRateLimiter;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;
/** Demo REST controller whose single endpoint is guarded by {@link ExtRateLimiter}. */
@RestController
public class StockController {

  /**
   * Stock-query demo endpoint. It is limited to 2 permits per second and may wait up to 500 ms
   * for a permit.
   *
   * @return the fixed string {@code "ok"}
   */
  @ExtRateLimiter(permitsPerSecond = 2, timeout = 500)
  @GetMapping("/getStock")
  public Object getStock() {
    return "ok";
  }
}
测试
每秒大约允许2个请求通过;超出的请求如果在500毫秒内仍拿不到令牌,就会直接返回“限流”。
http://localhost:8080/getStock