背景
实现平台化的 MyBatis 能力:用户在页面上输入 MyBatis 的 SQL 模板并传入参数,系统将其解析成可直接运行的完整 SQL。
实现原理
引入依赖:
<dependency>
<groupId>org.mybatis</groupId>
<artifactId>mybatis</artifactId>
<version>3.5.7</version>
</dependency>
MyBatis 的 SQL 生成器分两步:
- 解析 MyBatis 模板,生成带占位符 ? 的预编译 SQL;
- 解析预编译 SQL,把参数值依次填入占位符 ?,得到完整 SQL。
@Slf4j
public class MybatisGenerator {
private static final String HEAD = "<?xml version=\"1.0\" encoding=\"UTF-8\"?>"
+ "<!DOCTYPE mapper PUBLIC \"-//mybatis.org//DTD Mapper 3.0//EN\" \"http://mybatis"
+ ".org/dtd/mybatis-3-mapper.dtd\">"
+ "<mapper namespace=\"customGenerator\">"
+ "<select id=\"selectData\" parameterType=\"map\" resultType=\"map\">\n";
private static final String FOOT = "\n</select></mapper>";
private static final LoadingCache<String, MappedStatement> mappedStatementCache = CacheBuilder.newBuilder()
.refreshAfterWrite(1, TimeUnit.DAYS)
.build(new CacheLoader<String, MappedStatement>() {
@Override
public MappedStatement load(@NotNull String key) {
Configuration configuration = new Configuration();
configuration.setShrinkWhitespacesInSql(true);
String sourceSQL = HEAD + key + FOOT;
XMLMapperBuilder xmlMapperBuilder =
new XMLMapperBuilder(IOUtils.toInputStream(sourceSQL, Charset.forName("UTF-8")),
configuration, null,
null);
xmlMapperBuilder.parse();
return xmlMapperBuilder.getConfiguration().getMappedStatement("selectData");
}
});
//生成完整SQL
public static String generateDsl(SQLConfig apiConfig, Map<String, Object> conditions) {
String sql = apiConfig.getSqlTemplate();
try {
MappedStatement mappedStatement = mappedStatementCache.getUnchecked(sql);
BoundSql boundSql = mappedStatement.getBoundSql(conditions);
if (!boundSql.getParameterMappings().isEmpty()) {
List<PreparedStatementParameter> parameters = boundSql.getParameterMappings()
.stream().map(ParameterMapping::getProperty)
.map(param -> Optional.ofNullable(boundSql.getAdditionalParameter(param))
.orElseGet(() -> conditions.get(param)))
.map(PreparedStatementParameter::fromObject)
.collect(Collectors.toList());
//解析占位符,获取到完整SQL
return PreparedStatementParser.parse(boundSql.getSql()).buildSql(parameters);
} else {
return boundSql.getSql();
}
} catch (UncheckedExecutionException e) {
throw e;
}
}
@Data
public static class SQLConfig {
//SQL模板
private String sqlTemplate;
}
}
由于要把参数值直接拼接到占位符 ? 的位置,需要根据参数类型判断是否加单引号,并对字符串内容做转义处理。
/**
 * Formats Java values into the text that {@code PreparedStatementParameter} splices into SQL.
 *
 * <p>String-like values are escaped with backslash escapes (see {@link #ESCAPER}); whether the
 * result is wrapped in single quotes is decided by {@link #needsQuoting(Object)}. A SQL NULL is
 * represented internally by {@link #NULL_MARKER}.
 */
public final class ValueFormatter {
    /** Backslash-escapes characters that would otherwise terminate or corrupt a quoted literal. */
    private static final Escaper ESCAPER = Escapers.builder()
            .addEscape('\\', "\\\\")
            .addEscape('\n', "\\n")
            .addEscape('\t', "\\t")
            .addEscape('\b', "\\b")
            .addEscape('\f', "\\f")
            .addEscape('\r', "\\r")
            .addEscape('\u0000', "\\0")
            .addEscape('\'', "\\'")
            .addEscape('`', "\\`")
            .build();
    /** Internal marker for a NULL value. */
    public static final String NULL_MARKER = "\\N";
    /** Per-thread formatter because SimpleDateFormat is not thread-safe. */
    private static final ThreadLocal<SimpleDateFormat> DATE_TIME_FORMAT =
            ThreadLocal.withInitial(() -> new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"));

    /** Renders bytes as {@code \xNN} escape sequences; returns null for a null array. */
    public static String formatBytes(byte[] bytes) {
        if (bytes == null) {
            return null;
        } else {
            char[] hexArray =
                    new char[] {'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F'};
            char[] hexChars = new char[bytes.length * 4];
            for (int j = 0; j < bytes.length; ++j) {
                int v = bytes[j] & 255;
                hexChars[j * 4] = '\\';
                hexChars[j * 4 + 1] = 'x';
                hexChars[j * 4 + 2] = hexArray[v / 16];
                hexChars[j * 4 + 3] = hexArray[v % 16];
            }
            return new String(hexChars);
        }
    }

    public static String formatInt(int myInt) {
        return Integer.toString(myInt);
    }

    public static String formatDouble(double myDouble) {
        return Double.toString(myDouble);
    }

    /** Returns the character as-is (unescaped). */
    public static String formatChar(char myChar) {
        return Character.toString(myChar);
    }

    public static String formatLong(long myLong) {
        return Long.toString(myLong);
    }

    public static String formatFloat(float myFloat) {
        return Float.toString(myFloat);
    }

    /** Plain (non-scientific) notation; null maps to {@link #NULL_MARKER}. */
    public static String formatBigDecimal(BigDecimal myBigDecimal) {
        return myBigDecimal != null ? myBigDecimal.toPlainString() : NULL_MARKER;
    }

    public static String formatShort(short myShort) {
        return Short.toString(myShort);
    }

    /** Escapes the string; null maps to {@link #NULL_MARKER}. */
    public static String formatString(String myString) {
        return escape(myString);
    }

    public static String formatNull() {
        return NULL_MARKER;
    }

    public static String formatByte(byte myByte) {
        return Byte.toString(myByte);
    }

    /** Booleans are rendered as {@code 1} / {@code 0}. */
    public static String formatBoolean(boolean myBoolean) {
        return myBoolean ? "1" : "0";
    }

    public static String formatUUID(UUID x) {
        return x.toString();
    }

    public static String formatBigInteger(BigInteger x) {
        return x.toString();
    }

    /**
     * Formats an arbitrary value by dispatching on its runtime type.
     *
     * @param x the value; may be null
     * @return the formatted text, or null when {@code x} is null
     */
    public static String formatObject(Object x) {
        if (x == null) {
            return null;
        } else if (x instanceof Byte) {
            return formatInt(((Byte) x).intValue());
        } else if (x instanceof String) {
            return formatString((String) x);
        } else if (x instanceof Character) {
            // Characters are quoted like strings, so a quote or backslash must be escaped too.
            return escape(formatChar((Character) x));
        } else if (x instanceof BigDecimal) {
            return formatBigDecimal((BigDecimal) x);
        } else if (x instanceof Short) {
            return formatShort((Short) x);
        } else if (x instanceof Integer) {
            return formatInt((Integer) x);
        } else if (x instanceof Long) {
            return formatLong((Long) x);
        } else if (x instanceof Float) {
            return formatFloat((Float) x);
        } else if (x instanceof Double) {
            return formatDouble((Double) x);
        } else if (x instanceof byte[]) {
            return formatBytes((byte[]) x);
        } else if (x instanceof Boolean) {
            return formatBoolean((Boolean) x);
        } else if (x instanceof UUID) {
            return formatUUID((UUID) x);
        } else if (x instanceof BigInteger) {
            return formatBigInteger((BigInteger) x);
        } else if (x instanceof java.util.Date
                && !(x instanceof java.sql.Date || x instanceof java.sql.Time || x instanceof java.sql.Timestamp)) {
            // java.util.Date#toString ("Wed Jan 03 10:00:00 CST 2024") is not a SQL datetime literal.
            // The java.sql subclasses already print SQL-style text and fall through to the branches below.
            return getDateTimeFormat().format((java.util.Date) x);
        } else if (needsQuoting(x)) {
            // The text will be wrapped in single quotes, so it must be escaped like a String.
            return escape(String.valueOf(x));
        } else {
            return String.valueOf(x);
        }
    }

    /**
     * Whether the formatted value must be wrapped in single quotes: false for null, numbers,
     * booleans, arrays and collections; true for everything else.
     */
    public static boolean needsQuoting(Object o) {
        if (o == null) {
            return false;
        } else if (o instanceof Number) {
            return false;
        } else if (o instanceof Boolean) {
            return false;
        } else if (o.getClass().isArray()) {
            return false;
        } else {
            return !(o instanceof Collection);
        }
    }

    private static SimpleDateFormat getDateTimeFormat() {
        return DATE_TIME_FORMAT.get();
    }

    /** Backslash-escapes {@code s}; null maps to {@link #NULL_MARKER}. */
    public static String escape(String s) {
        return s == null ? NULL_MARKER : ESCAPER.escape(s);
    }

    /**
     * Wraps an identifier in back-quotes after escaping it.
     *
     * @throws IllegalArgumentException if {@code s} is null
     */
    public static String quoteIdentifier(String s) {
        if (s == null) {
            throw new IllegalArgumentException("Can't quote null as identifier");
        } else {
            StringBuilder sb = new StringBuilder(s.length() + 2);
            sb.append('`');
            sb.append(ESCAPER.escape(s));
            sb.append('`');
            return sb.toString();
        }
    }
}
定义预编译的参数:
/**
 * A single value to splice into a {@code ?} placeholder: its formatted text plus a flag telling
 * whether that text has to be wrapped in single quotes. A null value is stored as {@code \N}.
 */
public final class PreparedStatementParameter {
    /** Internal representation of a NULL value. */
    private static final String NULL_VALUE = "\\N";

    private static final PreparedStatementParameter NULL_PARAM = new PreparedStatementParameter(null, false);
    private static final PreparedStatementParameter TRUE_PARAM = new PreparedStatementParameter("1", false);
    private static final PreparedStatementParameter FALSE_PARAM = new PreparedStatementParameter("0", false);

    private final String stringValue;
    private final boolean quoteNeeded;

    /**
     * @param stringValue formatted value text; null is stored as {@code \N}
     * @param quoteNeeded whether the text must be wrapped in single quotes when rendered
     */
    public PreparedStatementParameter(String stringValue, boolean quoteNeeded) {
        this.stringValue = (stringValue != null) ? stringValue : NULL_VALUE;
        this.quoteNeeded = quoteNeeded;
    }

    /** Builds a parameter from an arbitrary value, letting ValueFormatter decide text and quoting. */
    public static PreparedStatementParameter fromObject(Object x) {
        if (x == null) {
            return NULL_PARAM;
        }
        String rendered = ValueFormatter.formatObject(x);
        boolean quoted = ValueFormatter.needsQuoting(x);
        return new PreparedStatementParameter(rendered, quoted);
    }

    /** The shared NULL parameter. */
    public static PreparedStatementParameter nullParameter() {
        return NULL_PARAM;
    }

    /** The shared {@code 1} / {@code 0} parameter for a boolean. */
    public static PreparedStatementParameter boolParameter(boolean value) {
        if (value) {
            return TRUE_PARAM;
        }
        return FALSE_PARAM;
    }

    /** Text to splice into SQL: {@code null} for NULL, otherwise the value, quoted if required. */
    String getRegularValue() {
        if (NULL_VALUE.equals(stringValue)) {
            return "null";
        }
        if (quoteNeeded) {
            return "'" + stringValue + "'";
        }
        return stringValue;
    }

    /** The raw stored text, without quoting. */
    String getBatchValue() {
        return stringValue;
    }

    @Override
    public String toString() {
        return stringValue;
    }
}
预编译 SQL 解析器:解析 SQL 中的占位符 ?,并把参数值填入对应位置。
/**
 * Splits SQL into literal parts and parameter slots so that bound values can be spliced back in.
 *
 * <p>Two modes exist:
 * <ul>
 *   <li>Plain mode: every {@code ?} outside single quotes, double quotes, back-quotes and comments
 *       becomes a placeholder ({@link #PARAM_MARKER}).</li>
 *   <li>Values mode: for SQL starting with {@code INSERT INTO ... VALUES (} (or when an explicit
 *       start position is given), scanning starts at the VALUES list and every comma-separated value
 *       inside parentheses is recorded, grouped per row. Literals are stored after
 *       {@link #typeTransformParameterValue}; placeholders are stored as {@code ?}. Placeholders that
 *       are not inside a parenthesised list after that point are not recorded in this mode.</li>
 * </ul>
 */
public class PreparedStatementParser {
    /** Parameter value recorded for an unbound {@code ?} placeholder. */
    static final String PARAM_MARKER = "?";
    /** Value that {@link #typeTransformParameterValue} records for a NULL literal. */
    static final String NULL_MARKER = "\\N";
    /**
     * Detects an {@code INSERT INTO ... VALUES (} statement. Anchored at the start of the SQL so the
     * phrase inside a string literal of another statement does not switch on values mode; reluctant
     * with a word boundary so the FIRST {@code VALUES (} is used, not e.g. the one in MySQL's
     * {@code ON DUPLICATE KEY UPDATE c = VALUES(c)}.
     */
    private static final Pattern VALUES = Pattern.compile(
            "(?i)\\A\\s*INSERT\\s+INTO\\s+.+?\\bVALUES\\s*\\(",
            Pattern.MULTILINE | Pattern.DOTALL);

    /** Recorded parameter values: one list per VALUES row, or a single list in plain mode. */
    private List<List<String>> parameters;
    /** Literal SQL text between parameters; always one more entry than recorded parameters. */
    private List<String> parts;
    private boolean valuesMode;

    private PreparedStatementParser() {
        parameters = new ArrayList<>();
        parts = new ArrayList<>();
        valuesMode = false;
    }

    /** Parses {@code sql}, auto-detecting values mode. */
    public static PreparedStatementParser parse(String sql) {
        return parse(sql, -1);
    }

    /**
     * Parses {@code sql}.
     *
     * @param sql               the SQL text
     * @param valuesEndPosition if positive, forces values mode and starts scanning at this index
     * @throws IllegalArgumentException if {@code sql} is blank
     */
    public static PreparedStatementParser parse(String sql, int valuesEndPosition) {
        if (StringUtils.isBlank(sql)) {
            throw new IllegalArgumentException("SQL may not be blank");
        }
        PreparedStatementParser parser = new PreparedStatementParser();
        parser.parseSQL(sql, valuesEndPosition);
        return parser;
    }

    List<List<String>> getParameters() {
        return Collections.unmodifiableList(parameters);
    }

    List<String> getParts() {
        return Collections.unmodifiableList(parts);
    }

    boolean isValuesMode() {
        return valuesMode;
    }

    private void reset() {
        parameters.clear();
        parts.clear();
        valuesMode = false;
    }

    /**
     * Single-pass scanner filling {@link #parts} and {@link #parameters}. In values mode
     * {@code idxStart}/{@code idxEnd} track the extent of the value currently being read inside the
     * parentheses; it is flushed on {@code ,} and on the {@code )} closing a row.
     */
    private void parseSQL(String sql, int valuesEndPosition) {
        reset();
        List<String> currentParamList = new ArrayList<String>();
        boolean afterBackSlash = false;
        boolean inQuotes = false;
        boolean inDoubleQuotes = false;
        boolean inBackQuotes = false;
        boolean inSingleLineComment = false;
        boolean inMultiLineComment = false;
        boolean whiteSpace = false;
        int endPosition = 0;
        if (valuesEndPosition > 0) {
            valuesMode = true;
            endPosition = valuesEndPosition;
        } else {
            Matcher matcher = VALUES.matcher(sql);
            if (matcher.find()) {
                valuesMode = true;
                // Position of the '(' opening the VALUES list.
                endPosition = matcher.end() - 1;
            }
        }
        int currentParensLevel = 0;
        int quotedStart = 0;
        int partStart = 0;
        int sqlLength = sql.length();
        for (int i = valuesMode ? endPosition : 0, idxStart = i, idxEnd = i; i < sqlLength; i++) {
            char c = sql.charAt(i);
            if (inSingleLineComment) {
                if (c == '\n') {
                    inSingleLineComment = false;
                }
            } else if (inMultiLineComment) {
                if (c == '*' && sqlLength > i + 1 && sql.charAt(i + 1) == '/') {
                    inMultiLineComment = false;
                    i++;
                }
            } else if (afterBackSlash) {
                // The character after a backslash is taken literally.
                afterBackSlash = false;
            } else if (c == '\\') {
                afterBackSlash = true;
            } else if (c == '\'' && !inBackQuotes && !inDoubleQuotes) {
                inQuotes = !inQuotes;
                if (inQuotes) {
                    quotedStart = i;
                } else {
                    // A closed quoted literal is the current value in values mode.
                    idxStart = quotedStart;
                    idxEnd = i + 1;
                }
            } else if (c == '"' && !inQuotes && !inBackQuotes) {
                // Double-quoted literal/identifier: a '?' inside it is not a placeholder.
                inDoubleQuotes = !inDoubleQuotes;
                if (inDoubleQuotes) {
                    quotedStart = i;
                } else {
                    idxStart = quotedStart;
                    idxEnd = i + 1;
                }
            } else if (c == '`' && !inQuotes && !inDoubleQuotes) {
                inBackQuotes = !inBackQuotes;
            } else if (!inQuotes && !inDoubleQuotes && !inBackQuotes) {
                if (c == '?') {
                    if (currentParensLevel > 0) {
                        idxStart = i;
                        idxEnd = i + 1;
                    }
                    if (!valuesMode) {
                        parts.add(sql.substring(partStart, i));
                        partStart = i + 1;
                        currentParamList.add(PARAM_MARKER);
                    }
                } else if (c == '-' && sqlLength > i + 1 && sql.charAt(i + 1) == '-') {
                    inSingleLineComment = true;
                    i++;
                } else if (c == '/' && sqlLength > i + 1 && sql.charAt(i + 1) == '*') {
                    inMultiLineComment = true;
                    i++;
                } else if (c == ',') {
                    // Value separator: flush the value read so far.
                    if (valuesMode && idxEnd > idxStart) {
                        currentParamList.add(typeTransformParameterValue(sql.substring(idxStart, idxEnd)));
                        parts.add(sql.substring(partStart, idxStart));
                        partStart = idxEnd;
                        idxEnd = i;
                        idxStart = idxEnd;
                    }
                    idxStart++;
                    idxEnd++;
                } else if (c == '(') {
                    currentParensLevel++;
                    idxStart++;
                    idxEnd++;
                } else if (c == ')') {
                    currentParensLevel--;
                    // Closing a VALUES row: flush the last value and the row itself.
                    if (valuesMode && currentParensLevel == 0) {
                        if (idxEnd > idxStart) {
                            currentParamList.add(typeTransformParameterValue(sql.substring(idxStart, idxEnd)));
                            parts.add(sql.substring(partStart, idxStart));
                            partStart = idxEnd;
                            idxEnd = i;
                            idxStart = idxEnd;
                        }
                        if (!currentParamList.isEmpty()) {
                            parameters.add(currentParamList);
                            currentParamList = new ArrayList<>(currentParamList.size());
                        }
                    }
                } else if (Character.isWhitespace(c)) {
                    whiteSpace = true;
                } else if (currentParensLevel > 0) {
                    // Ordinary token character: start a new value after whitespace, else extend it.
                    if (whiteSpace) {
                        idxStart = i;
                        idxEnd = i + 1;
                    } else {
                        idxEnd++;
                    }
                    whiteSpace = false;
                }
            }
        }
        if (!valuesMode && !currentParamList.isEmpty()) {
            parameters.add(currentParamList);
        }
        String lastPart = sql.substring(partStart, sqlLength);
        parts.add(lastPart);
    }

    /** Normalizes a literal recorded in values mode: TRUE/FALSE to 1/0, NULL to {@link #NULL_MARKER}. */
    private static String typeTransformParameterValue(String paramValue) {
        if (paramValue == null) {
            return null;
        }
        if (Boolean.TRUE.toString().equalsIgnoreCase(paramValue)) {
            return "1";
        }
        if (Boolean.FALSE.toString().equalsIgnoreCase(paramValue)) {
            return "0";
        }
        if ("NULL".equalsIgnoreCase(paramValue)) {
            return NULL_MARKER;
        }
        return paramValue;
    }

    /**
     * Rebuilds the SQL with each {@code ?} placeholder replaced, in order, by the next bind's
     * {@link PreparedStatementParameter#getRegularValue()}. Non-placeholder values recorded in values
     * mode are written back as literals, with {@link #NULL_MARKER} rendered as {@code NULL}.
     *
     * @param binds values for the placeholders, in order; extra values are ignored
     * @return the complete SQL
     * @throws IllegalArgumentException if fewer binds than placeholders are supplied
     */
    public String buildSql(List<PreparedStatementParameter> binds) {
        if (parts.size() == 1) {
            return parts.get(0);
        }
        // Flatten the per-row lists once instead of re-walking them for every part.
        List<String> flatParams = new ArrayList<>();
        for (List<String> row : parameters) {
            flatParams.addAll(row);
        }
        int placeholderCount = 0;
        for (String p : flatParams) {
            if (PARAM_MARKER.equals(p)) {
                placeholderCount++;
            }
        }
        int bindCount = binds == null ? 0 : binds.size();
        if (bindCount < placeholderCount) {
            throw new IllegalArgumentException("SQL has " + placeholderCount
                    + " placeholder(s) but only " + bindCount + " parameter(s) were supplied");
        }
        StringBuilder sb = new StringBuilder(parts.get(0));
        int bindIndex = 0;
        for (int i = 1; i < parts.size(); i++) {
            // A missing parameter yields null, which StringBuilder renders as "null".
            String pValue = i - 1 < flatParams.size() ? flatParams.get(i - 1) : null;
            if (PARAM_MARKER.equals(pValue)) {
                // Placeholder from #{}: value is quoted/escaped by the parameter itself.
                sb.append(binds.get(bindIndex++).getRegularValue());
            } else if (NULL_MARKER.equals(pValue)) {
                // \N is an internal marker, not portable SQL.
                sb.append("NULL");
            } else {
                sb.append(pValue);
            }
            sb.append(parts.get(i));
        }
        return sb.toString();
    }
}