SpringBoot整合RateLimiter实现限流

Java
214
0
0
2024-02-27
标签   SpringBoot

写作缘由

在向某学长炫耀自己会用Redis+Lua实现滑动窗口限流时,他说现在都用RateLimiter,所以我就想搞个Demo。但是度娘了一下,感觉搜索到的博客有几处个人认为不太完善的地方,比如只贴了部分代码,没贴依赖,尤其是用AOP实现的时候,依赖用哪个其实还是有讲究的;还有一个问题是大多都基于AOP实现,而拦截器也是一个不错的方式,所以此处用拦截器HandlerInterceptorAdapter实现。

源码下载

https://github.com/cbeann/Demooo/tree/master/springboot-ratelimiter

部分代码

pom

 <!-- https://mvnrepository.com/artifact/com.google.guava/guava -->
        <!-- Guava supplies com.google.common.util.concurrent.RateLimiter, the token bucket used by RateLimiterInceptor. -->
        <dependency>
            <groupId>com.google.guava</groupId>
            <artifactId>guava</artifactId>
            <version>28.1-jre</version>
        </dependency>

自定义注解ExtRateLimiter

package com.example.annotation;

import java.lang.annotation.*;

/**
 * @author chaird
 * @create 2021-03-20 17:57
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD})
public @interface ExtRateLimiter {

  /**
   * Fixed rate, in permits per second, at which tokens are added to the method's token bucket.
   *
   * @return permits added per second
   */
  double permitsPerSecond();

  /**
   * Longest time, in milliseconds, a request may wait for a token; when no token is obtained within
   * this window the request goes straight to the fallback (degradation) handling.
   *
   * @return maximum wait in milliseconds
   */
  long timeout();
}

拦截器

package com.example.Interceptor;

import com.example.annotation.ExtRateLimiter;
import com.google.common.util.concurrent.RateLimiter;
import org.springframework.stereotype.Component;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.ModelAndView;
import org.springframework.web.servlet.handler.HandlerInterceptorAdapter;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.PrintWriter;
import java.lang.reflect.Method;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;

/**
 * Spring MVC interceptor that enforces {@link ExtRateLimiter} on controller handler methods.
 *
 * <p>Each annotated handler method gets its own Guava {@link RateLimiter}, created lazily on the
 * first request using the annotation's {@code permitsPerSecond}. A request waits at most {@code
 * timeout} milliseconds for a permit. If none is available, the handler is not invoked and the
 * request is answered with HTTP 429 and the plain-text body {@code 限流} ("rate limited").
 * Requests to handlers without the annotation, or to non-controller handlers, pass through.
 *
 * <p>NOTE(review): {@code HandlerInterceptorAdapter} is deprecated as of Spring 5.3. On that
 * version or later, prefer implementing {@code HandlerInterceptor} directly.
 *
 * @author CBeann
 * @create 2020-07-04 18:06
 */
@Component
public class RateLimiterInceptor extends HandlerInterceptorAdapter {

  /** HTTP 429 Too Many Requests; the Servlet API defines no named constant for it. */
  private static final int SC_TOO_MANY_REQUESTS = 429;

  /** One limiter per annotated handler method, keyed by the method's full signature. */
  private final Map<String, RateLimiter> rateHashMap = new ConcurrentHashMap<>();

  /**
   * Lets the request through if its handler is not rate-limited or a permit is obtained in time;
   * otherwise writes a 429 rejection and stops the chain.
   *
   * @return {@code true} to continue to the handler, {@code false} if the request was rejected
   */
  @Override
  public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler)
      throws Exception {
    // Only controller methods can carry the annotation; static resources etc. pass through.
    if (!(handler instanceof HandlerMethod)) {
      return true;
    }

    Method method = ((HandlerMethod) handler).getMethod();
    ExtRateLimiter extRateLimiter = method.getAnnotation(ExtRateLimiter.class);
    if (extRateLimiter == null) {
      return true;
    }

    RateLimiter rateLimiter = limiterFor(method, extRateLimiter);
    if (rateLimiter.tryAcquire(extRateLimiter.timeout(), TimeUnit.MILLISECONDS)) {
      return true;
    }

    // No permit within the timeout: reject with 429 so clients can tell throttling from success.
    response.setStatus(SC_TOO_MANY_REQUESTS);
    // The body is plain text, not JSON, so declare it accordingly.
    response.setContentType("text/plain; charset=utf-8");
    PrintWriter writer = response.getWriter();
    writer.print("限流");
    writer.flush();
    return false;
  }

  /**
   * Returns the limiter for {@code method}, creating it atomically on first use.
   *
   * <p>{@code computeIfAbsent} ensures exactly one limiter per key even when first requests race.
   * A separate get/put could let concurrent threads each install their own limiter and briefly
   * exceed the configured rate.
   *
   * @param method the handler method being invoked
   * @param extRateLimiter the annotation found on {@code method}
   * @return the shared limiter for this handler method
   */
  private RateLimiter limiterFor(Method method, ExtRateLimiter extRateLimiter) {
    // toGenericString() includes the declaring class and parameter types, so overloaded handlers
    // do not share a limiter and class/method names cannot run together into the same key.
    String key = method.toGenericString();
    return rateHashMap.computeIfAbsent(
        key, ignored -> RateLimiter.create(extRateLimiter.permitsPerSecond()));
  }
}

拦截器配置

package com.example.config;

import com.example.Interceptor.RateLimiterInceptor;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurerAdapter;

/**
 * Spring MVC configuration that registers {@link RateLimiterInceptor} for incoming requests.
 *
 * <p>NOTE(review): {@code WebMvcConfigurerAdapter} is deprecated in Spring 5+. If the project is on
 * Spring 5, implement {@code WebMvcConfigurer} directly instead. Confirm the Spring version first.
 *
 * @author chaird
 * @create 2020-09-23 16:13
 */
@Configuration
public class MVCConfig extends WebMvcConfigurerAdapter {

  @Autowired private RateLimiterInceptor rateLimiterInceptor;

  /**
   * Registers the rate-limit interceptor for request paths of any depth.
   *
   * <p>The pattern is {@code /**} rather than {@code /*}. The latter matches only single-segment
   * paths such as {@code /getStock}, so annotated handlers mapped under nested paths (for example
   * {@code /api/stock}) would never be rate-limited. Intercepting every path is safe because the
   * interceptor lets handlers without {@code @ExtRateLimiter} pass through.
   */
  @Override
  public void addInterceptors(InterceptorRegistry registry) {
    registry.addInterceptor(rateLimiterInceptor).addPathPatterns("/**");
  }
}

controller

package com.example.controller;

import com.example.annotation.ExtRateLimiter;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;

@RestController
public class StockController {

  /**
   * Demo endpoint returning a fixed payload.
   *
   * <p>Rate-limited through {@code @ExtRateLimiter}: 2 permits per second, waiting at most 500 ms
   * for a permit.
   *
   * @return the constant string {@code "ok"}
   */
  @ExtRateLimiter(permitsPerSecond = 2, timeout = 500)
  @GetMapping("/getStock")
  public Object getStock() {
    return "ok";
  }
}

测试

大约每秒允许2个请求;超出速率的请求若在500毫秒内仍未获取到令牌,则直接返回“限流”。

http://localhost:8080/getStock