MyBatisPlus的SQL注入器
一、介绍
在前些天的时候,我需要写一个存在则更新的sql
语句,这以前我有记录过。
MySQL插入重复后进行覆盖更新 | 半月无霜 (banmoon.top)
但以前我是在mapping.xml
文件中,自己手动拼出来的。
虽然可以实现,但真的好麻烦,每个实体都要这样写吗?
我不,我去看了MyBatis plus
的BaseMapper
是如何实现的。
嘿,还真的让我找到了,不多说,上代码。
二、代码
在MP
中,有一个接口ISqlInjector.java
,它的一个实现类DefaultSqlInjector.java
,截图看看
可以看到,它自己弄了点方法注入进去了,所以我们只要依葫芦画瓢,也就能写出自己的方法;
1)编写方法
我们编写一个类似于Insert.java
的这样一个类,我们取名为InsertOnDuplicateKeyUpdateMethod.java
package com.banmoon.business.mybatis.method; | |
import com.baomidou.mybatisplus.annotation.IdType; | |
import com.baomidou.mybatisplus.core.injector.AbstractMethod; | |
import com.baomidou.mybatisplus.core.metadata.TableFieldInfo; | |
import com.baomidou.mybatisplus.core.metadata.TableInfo; | |
import com.baomidou.mybatisplus.core.metadata.TableInfoHelper; | |
import com.baomidou.mybatisplus.core.toolkit.StringUtils; | |
import com.baomidou.mybatisplus.core.toolkit.sql.SqlInjectionUtils; | |
import com.baomidou.mybatisplus.core.toolkit.sql.SqlScriptUtils; | |
import org.apache.ibatis.executor.keygen.Jdbc3KeyGenerator; | |
import org.apache.ibatis.executor.keygen.KeyGenerator; | |
import org.apache.ibatis.executor.keygen.NoKeyGenerator; | |
import org.apache.ibatis.mapping.MappedStatement; | |
import org.apache.ibatis.mapping.SqlSource; | |
import java.lang.reflect.Field; | |
import java.util.List; | |
import java.util.Optional; | |
import java.util.stream.Collectors; | |
public class InsertOnDuplicateKeyUpdateMethod extends AbstractMethod { | |
/** | |
* 方法名 | |
*/ | |
public static final String METHOD_NAME = "insertOnDuplicateKeyUpdate"; | |
/** | |
* 插入SQL模板 | |
*/ | |
public static final String SQL_TEMPLATE = "<script>\nINSERT INTO %s %s VALUES %s\n ON DUPLICATE KEY UPDATE\n %s\n</script>"; | |
/** | |
* 字段重复插入覆盖片段,使用旧值 | |
*/ | |
public static final String FIELD_FRAGMENT_OLD_VALUE = "\n%s = %s"; | |
/** | |
* 字段重复插入覆盖片段,使用新值 | |
*/ | |
public static final String FIELD_FRAGMENT_NEW_VALUE = "\n%s = VALUES(%s)"; | |
public InsertOnDuplicateKeyUpdateMethod() { | |
super(METHOD_NAME); | |
} | |
protected InsertOnDuplicateKeyUpdateMethod(String methodName) { | |
super(methodName); | |
} | |
public MappedStatement injectMappedStatement(Class<?> mapperClass, Class<?> modelClass, TableInfo tableInfo) { | |
KeyGenerator keyGenerator = NoKeyGenerator.INSTANCE; | |
String columnScript = SqlScriptUtils.convertTrim(tableInfo.getAllInsertSqlColumnMaybeIf(null), | |
LEFT_BRACKET, RIGHT_BRACKET, null, COMMA); | |
String valuesScript = SqlScriptUtils.convertTrim(tableInfo.getAllInsertSqlPropertyMaybeIf(null), | |
LEFT_BRACKET, RIGHT_BRACKET, null, COMMA); | |
String keyProperty = null; | |
String keyColumn = null; | |
// 表包含主键处理逻辑,如果不包含主键当普通字段处理 | |
if (StringUtils.isNotBlank(tableInfo.getKeyProperty())) { | |
if (tableInfo.getIdType() == IdType.AUTO) { | |
/* 自增主键 */ | |
keyGenerator = Jdbc3KeyGenerator.INSTANCE; | |
keyProperty = tableInfo.getKeyProperty(); | |
// 去除转义符 | |
keyColumn = SqlInjectionUtils.removeEscapeCharacter(tableInfo.getKeyColumn()); | |
} else if (null != tableInfo.getKeySequence()) { | |
keyGenerator = TableInfoHelper.genKeyGenerator(methodName, tableInfo, builderAssistant); | |
keyProperty = tableInfo.getKeyProperty(); | |
keyColumn = tableInfo.getKeyColumn(); | |
} | |
} | |
String duplicateKeyUpdateScript = generateDuplicateKeyUpdateScript(tableInfo); | |
String sql = String.format(SQL_TEMPLATE, tableInfo.getTableName(), columnScript, valuesScript, duplicateKeyUpdateScript); | |
SqlSource sqlSource = super.createSqlSource(configuration, sql, modelClass); | |
return this.addInsertMappedStatement(mapperClass, modelClass, methodName, sqlSource, keyGenerator, keyProperty, keyColumn); | |
} | |
protected String generateDuplicateKeyUpdateScript(TableInfo tableInfo) { | |
List<TableFieldInfo> fieldList = tableInfo.getFieldList(); | |
return fieldList.stream().map(f -> { | |
Field field = f.getField(); | |
String column = f.getColumn(); | |
OnDuplicateKeyUpdate onDuplicateKeyUpdate = field.getAnnotation(OnDuplicateKeyUpdate.class); | |
boolean newValue = Optional.ofNullable(onDuplicateKeyUpdate).map(OnDuplicateKeyUpdate::newValue).orElse(true); | |
if (newValue) { | |
return String.format(FIELD_FRAGMENT_NEW_VALUE, column, column); | |
} else { | |
return String.format(FIELD_FRAGMENT_OLD_VALUE, column, column); | |
} | |
}).collect(Collectors.joining(",")); | |
} | |
} |
大部分代码,都是和Insert.java
是一样的,我们主要是增强了ON DUPLICATE KEY UPDATE
后面的部分。
里面还有一个注解OnDuplicateKeyUpdate.java
,主要是判断重复导致更新时,是使用当前的值,还是使用插入的新值
package com.banmoon.business.mybatis.method; | |
import java.lang.annotation.ElementType; | |
import java.lang.annotation.Retention; | |
import java.lang.annotation.RetentionPolicy; | |
import java.lang.annotation.Target; | |
public OnDuplicateKeyUpdate { | |
/** | |
* 重复插入覆盖时,使用新值还是旧值 <br> | |
* 默认使用新值 | |
*/ | |
boolean newValue() default true; | |
} |
2)SqlInjector
上面有说到DefaultSqlInjector.java
,里面添加了许多方法,但是没有我们刚刚写的;
要把刚刚的方法加进去,直接继承它写一个自己的BanmoonSqlInjector.java
package com.banmoon.business.mybatis; | |
import com.banmoon.business.mybatis.method.InsertOnDuplicateKeyUpdateBatchMethod; | |
import com.banmoon.business.mybatis.method.InsertOnDuplicateKeyUpdateMethod; | |
import com.baomidou.mybatisplus.core.injector.AbstractMethod; | |
import com.baomidou.mybatisplus.core.injector.DefaultSqlInjector; | |
import com.baomidou.mybatisplus.core.metadata.TableInfo; | |
import java.util.List; | |
public class BanmoonSqlInjector extends DefaultSqlInjector { | |
public List<AbstractMethod> getMethodList(Class<?> mapperClass, TableInfo tableInfo) { | |
List<AbstractMethod> methodList = super.getMethodList(mapperClass, tableInfo); | |
// 添加自己的方法 | |
methodList.add(new InsertOnDuplicateKeyUpdateMethod()); | |
return methodList; | |
} | |
} |
好了,再将它放入Spring
容器中
package com.banmoon.business.config; | |
import com.banmoon.business.mybatis.BanmoonSqlInjector; | |
import org.springframework.context.annotation.Bean; | |
import org.springframework.context.annotation.Configuration; | |
public class MybatisEnhanceConfig { | |
public BanmoonSqlInjector banmoonSqlInjector() { | |
return new BanmoonSqlInjector(); | |
} | |
} |
3)BaseMapper
差点忘记了BaseMapper
,这里面可没有我们写的方法;所以同样的,直接继承,写一个自己的
package com.banmoon.business.mybatis; | |
import com.baomidou.mybatisplus.core.mapper.BaseMapper; | |
public interface BanmoonMapper<T> extends BaseMapper<T> { | |
/** | |
* 插入一条记录 | |
* | |
* @param entity 实体对象 | |
*/ | |
int insertOnDuplicateKeyUpdate(T entity); | |
} |
三、测试
好了,上面的代码编写完毕,直接开始测试
package com.banmoon; | |
import com.banmoon.entity.UserEntity; | |
import com.banmoon.mapper.UserMapper; | |
import org.junit.Assert; | |
import org.junit.Test; | |
import org.junit.runner.RunWith; | |
import org.springframework.boot.test.context.SpringBootTest; | |
import org.springframework.test.context.junit4.SpringRunner; | |
import javax.annotation.Resource; | |
public class ServerTest { | |
private UserMapper userMapper; | |
public void test() { | |
UserEntity userEntity = new UserEntity(); | |
userEntity.setId(5); | |
userEntity.setUsername("测试"); | |
userEntity.setPassword("1234"); | |
userEntity.setStatus(1); | |
int i = userMapper.insertOnDuplicateKeyUpdate(userEntity); | |
Assert.assertEquals(i, 1); | |
userEntity.setUsername("测试覆盖"); | |
userMapper.insertOnDuplicateKeyUpdate(userEntity); | |
UserEntity entity = userMapper.selectById(5); | |
Assert.assertEquals("测试覆盖", entity.getUsername()); | |
} | |
} |
查看日志打印的信息,可以看到后面的字段已经贴上了
查看数据库最后的结果
四、最后
还差一个批量插入重复覆盖的,这个后面补上。
我是半月,你我一同共勉!!!