蹲厕所的熊

benjaminwhx

Mybatis Plugin原理分析

2018-05-26 作者: 吴海旭


先给出两种配置方式:mybatis-config.xml 中的配置,以及 Spring 中的配置。

<!-- mybatis configuration配置文件中配置 -->
<plugins>
    <plugin interceptor="com.xx.xx.SqlMonitorInterceptor">
        <property name="slowSqlTimeout" value="2000"/>
    </plugin>
</plugins>

<!-- spring中的配置 -->
<bean id="sqlSessionFactory" class="org.mybatis.spring.SqlSessionFactoryBean">
    <property name="plugins">
        <array>
            <ref bean="monitorInterceptor" />
        </array>
    </property>
</bean>

<bean id="monitorInterceptor" class="com.xx.xx.SqlMonitorInterceptor">
    <property name="slowSqlTimeout" value="2000"/>
</bean>

配置完成后,MyBatis 在解析配置时(即下面的 pluginElement 方法)会实例化配置的拦截器,并把它们加入 Configuration 的拦截器链中。

// Interceptor chain of this Configuration instance: every configured plugin is registered here
// (see addInterceptor) and later applied to newly created objects through pluginAll.
protected final InterceptorChain interceptorChain = new InterceptorChain();

// Parses the <plugins> XML node: for each <plugin> child, instantiates the class named by its
// "interceptor" attribute, passes the child's nested properties to setProperties, and registers
// the instance with the configuration. Does nothing when the node is absent (parent == null).
private void pluginElement(XNode parent) throws Exception {
  if (parent != null) {
    for (XNode child : parent.getChildren()) {
      String interceptor = child.getStringAttribute("interceptor");
      Properties properties = child.getChildrenAsProperties();
      // Class.newInstance() is deprecated since Java 9: it propagates any checked exception thrown
      // by the constructor without declaring it. getDeclaredConstructor().newInstance() reports
      // constructor failures as InvocationTargetException instead (allowed by "throws Exception").
      Interceptor interceptorInstance =
          (Interceptor) resolveClass(interceptor).getDeclaredConstructor().newInstance();
      interceptorInstance.setProperties(properties);
      configuration.addInterceptor(interceptorInstance);
    }
  }
}

/**
 * Registers a plugin with this configuration by appending it to the interceptor chain.
 *
 * @param interceptor the plugin to add; it is placed after all previously registered plugins
 */
public void addInterceptor(Interceptor interceptor) {
    this.interceptorChain.addInterceptor(interceptor);
}

Configuration类中有几个方法通过 interceptorChain.pluginAll() 调用了自定义拦截器的plugin方法来生成代理对象,源码如下:

/**
 * Creates the ParameterHandler for a statement through the statement's language driver,
 * then gives every registered plugin the chance to wrap it.
 */
public ParameterHandler newParameterHandler(MappedStatement mappedStatement, Object parameterObject, BoundSql boundSql) {
    ParameterHandler created = mappedStatement.getLang()
            .createParameterHandler(mappedStatement, parameterObject, boundSql);
    return (ParameterHandler) interceptorChain.pluginAll(created);
}

/**
 * Creates a DefaultResultSetHandler for the statement, then gives every registered plugin
 * the chance to wrap it.
 */
public ResultSetHandler newResultSetHandler(Executor executor, MappedStatement mappedStatement, RowBounds rowBounds, ParameterHandler parameterHandler,ResultHandler resultHandler, BoundSql boundSql) {
    ResultSetHandler created = new DefaultResultSetHandler(
            executor, mappedStatement, parameterHandler, resultHandler, boundSql, rowBounds);
    return (ResultSetHandler) interceptorChain.pluginAll(created);
}

/**
 * Creates a RoutingStatementHandler for the statement, then gives every registered plugin
 * the chance to wrap it.
 */
public StatementHandler newStatementHandler(Executor executor, MappedStatement mappedStatement, Object parameterObject, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
    StatementHandler created = new RoutingStatementHandler(
            executor, mappedStatement, parameterObject, rowBounds, resultHandler, boundSql);
    return (StatementHandler) interceptorChain.pluginAll(created);
}

// Creates the Executor for the given transaction and executor type (construction details are
// elided in this excerpt), then lets every registered plugin wrap it via pluginAll.
public Executor newExecutor(Transaction transaction, ExecutorType executorType) {
    ...
    executor = (Executor) interceptorChain.pluginAll(executor);
    return executor;
}

我们再来看看拦截链的源码。实际上,MyBatis 的四大对象 StatementHandler、ParameterHandler、ResultSetHandler 和 Executor 在创建时都会调用 pluginAll 方法,因此都会进入你配置的拦截器的 plugin 方法:

/**
 * Ordered collection of the configured plugins. {@link #pluginAll(Object)} hands the target to
 * each interceptor in registration order, so a later interceptor receives whatever the earlier
 * ones returned.
 */
public class InterceptorChain {

  private final List<Interceptor> interceptors = new ArrayList<>();

  /**
   * Passes {@code target} through every registered interceptor's {@code plugin} method.
   *
   * @param target the object to be (possibly) wrapped
   * @return the value returned by the last {@code plugin} call, or {@code target} itself when
   *     no interceptor is registered
   */
  public Object pluginAll(Object target) {
    Object wrapped = target;
    for (Interceptor current : interceptors) {
      wrapped = current.plugin(wrapped);
    }
    return wrapped;
  }

  /** Appends an interceptor to the end of the chain. */
  public void addInterceptor(Interceptor interceptor) {
    this.interceptors.add(interceptor);
  }

  /** Returns a read-only view of the registered interceptors, in registration order. */
  public List<Interceptor> getInterceptors() {
    return Collections.unmodifiableList(this.interceptors);
  }

}

自定义拦截器的 plugin 方法通常直接调用 Plugin 的 wrap 方法来生成代理对象,我们来看看 wrap 的源码:

/**
 * Wraps {@code target} in a JDK dynamic proxy whose handler is a {@code Plugin}, but only when
 * the target implements at least one interface selected from the interceptor's signature map;
 * otherwise {@code target} is returned unchanged.
 */
public static Object wrap(Object target, Interceptor interceptor) {
  // Types and methods declared in the interceptor's @Intercepts configuration.
  Map<Class<?>, Set<Method>> signatureMap = getSignatureMap(interceptor);
  Class<?> targetClass = target.getClass();
  // Interfaces of the target chosen by getAllInterfaces (presumably those present in signatureMap).
  Class<?>[] proxiedInterfaces = getAllInterfaces(targetClass, signatureMap);
  if (proxiedInterfaces.length == 0) {
    // Nothing to intercept on this target: hand back the original object.
    return target;
  }
  ClassLoader loader = targetClass.getClassLoader();
  // Calls on the proxied interfaces are routed to the Plugin handler's invoke method.
  Plugin handler = new Plugin(target, interceptor, signatureMap);
  return Proxy.newProxyInstance(loader, proxiedInterfaces, handler);
}

如果 plugin 返回的是 target 的代理对象,那么之后调用该对象上被代理接口的方法时,就会进入 Plugin 的 invoke 方法(Plugin 就是传给 Proxy.newProxyInstance 的 InvocationHandler;不熟悉的读者可以先了解 JDK 动态代理的相关知识):

/**
 * Proxy dispatch: a method listed in the signature map for its declaring type is routed to the
 * interceptor's {@code intercept}; any other method is invoked directly on the wrapped target.
 */
@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
  try {
    Set<Method> interceptedMethods = signatureMap.get(method.getDeclaringClass());
    boolean intercepted = interceptedMethods != null && interceptedMethods.contains(method);
    if (!intercepted) {
      return method.invoke(target, args);
    }
    return interceptor.intercept(new Invocation(target, method, args));
  } catch (Exception e) {
    // Rethrow what ExceptionUtil.unwrapThrowable extracts (presumably the cause hidden inside
    // reflection wrappers such as InvocationTargetException).
    throw ExceptionUtil.unwrapThrowable(e);
  }
}

既然执行会走到自定义拦截器的 intercept 方法,我们就可以在其中加入自己的逻辑来拦截 SQL 的执行,比如实现分页、分库分表等功能。下面给出一个监控慢 SQL、并在连接池使用率过高时报警的插件类(即上文配置中的拦截器,这里类名写作 MonitorInterceptor):

/**
 * MyBatis plugin that watches {@code Executor.query} / {@code Executor.update} calls and raises
 * alarms through {@code monitor.alarm(...)} for:
 * <ul>
 *   <li>slow statements (execution time above {@code slowSqlTimeout} milliseconds);</li>
 *   <li>a {@code BasicDataSource} whose active/max connection ratio reaches
 *       {@code maxActiveConnectionRatio} (alarms are throttled to at most one per five minutes,
 *       shared by all instances of this class);</li>
 *   <li>exceptions thrown by the intercepted call (the exception is always re-thrown).</li>
 * </ul>
 * Failures inside the monitoring logic itself are logged and never change the outcome of the
 * SQL call. Settings can be supplied through the setters or through {@code <property>} elements
 * handled by {@link #setProperties(Properties)}.
 */
@Intercepts({
        @Signature(
                type = Executor.class,
                method = "query",
                args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
        @Signature(
                type = Executor.class,
                method = "query",
                args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}),
        @Signature(
                type = Executor.class,
                method = "update",
                args = {MappedStatement.class, Object.class})
})
public class MonitorInterceptor implements Interceptor {

    private static final Logger logger = LoggerFactory.getLogger(MonitorInterceptor.class);

    /**
     * Smallest accepted slow-SQL timeout (ms); smaller values are ignored so alarms do not fire too often.
     */
    public static final long MIN_SLOW_SQL_TIMEOUT = 20;

    /**
     * Smallest accepted active-connection ratio; smaller values are ignored so alarms do not fire too often.
     */
    public static final float MIN_MAX_ACTIVE_CONNECTION_RATIO = 0.3f;

    /**
     * Minimum interval between two "too many active connections" alarms: 5 minutes.
     */
    private static final long CONNECTION_ALARM_INTERVAL_MS = 5 * 60 * 1000L;

    /**
     * How many times the active-connection ratio was found at or above the threshold (all instances).
     * Never reassigned, so {@code final} is enough; the AtomicLong itself provides thread safety.
     */
    private static final AtomicLong tooManyActiveConnectionAlarmTimes = new AtomicLong(0);

    /**
     * Epoch millis of the last "too many active connections" alarm; 0 means none sent yet.
     * Updated with compareAndSet so concurrent callers cannot both pass the throttle.
     */
    private static final AtomicLong tooManyActiveConnectionLatestAlarmMs = new AtomicLong(0);

    /**
     * Switch for SQL-exception alarms.
     */
    private boolean sqlExceptionEnabled = true;

    /**
     * Switch for slow-SQL alarms.
     */
    private boolean slowSqlEnabled = true;

    /**
     * Slow-SQL threshold in milliseconds.
     */
    private long slowSqlTimeout = 1000;

    /**
     * Switch for "too many active connections" alarms.
     */
    private boolean tooManyActiveConnectionEnabled = true;

    /**
     * Active/max connection ratio at or above which a warning (and possibly an alarm) is raised.
     */
    private float maxActiveConnectionRatio = 0.7f;

    // Alarm sink (simulated alarm implementation).
    private Monitor monitor = new DefaultMonitor();

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object[] args = invocation.getArgs();
        MappedStatement statement = (MappedStatement) args[0];
        Object parameterObject = null;
        String sqlId = null;
        BoundSql sql = null;
        try {
            parameterObject = args[1];
            sqlId = statement.getId();
            sql = statement.getBoundSql(parameterObject);

            // Execute the intercepted Executor method and time it.
            long startMs = System.currentTimeMillis();
            Object result = invocation.proceed();
            long usedTimeInMs = System.currentTimeMillis() - startMs;

            try {
                if (slowSqlEnabled) {
                    checkSlowSql(sqlId, sql, parameterObject, usedTimeInMs);
                }
                if (tooManyActiveConnectionEnabled) {
                    checkActiveConnections(statement);
                }
            } catch (Throwable t) {
                // Monitoring is best-effort: never fail a successful SQL call. Pass the Throwable
                // without a placeholder so SLF4J logs its stack trace.
                logger.error("数据库监控插件出现异常", t);
            }

            return result;
        } catch (Throwable e) {
            if (sqlExceptionEnabled) {
                try {
                    reportSqlException(e, statement, sqlId, sql, parameterObject);
                } catch (Throwable t) {
                    // A failure while reporting (e.g. parameter serialization) must not replace
                    // the original SQL exception re-thrown below.
                    logger.error("数据库监控插件出现异常", t);
                }
            }
            throw e;
        }
    }

    /**
     * Logs and alarms when a statement ran longer than {@code slowSqlTimeout} milliseconds.
     */
    private void checkSlowSql(String sqlId, BoundSql sql, Object parameterObject, long usedTimeInMs) {
        if (usedTimeInMs <= this.slowSqlTimeout) {
            return;
        }
        if (logger.isWarnEnabled()) {
            String params = toString(parameterObject);
            logger.warn("sqlId={}执行耗时{}毫秒,超过阀值{},执行的sql语句是[{}], 参数值:[{}]", sqlId, usedTimeInMs, slowSqlTimeout, sql.getSql(), params);
        }

        // Simulated alarm.
        monitor.alarm(sqlId + "耗时" + usedTimeInMs + "毫秒,超过阀值" + this.slowSqlTimeout);
    }

    /**
     * Logs when the pool's active/max connection ratio reaches the threshold, and alarms at most
     * once per {@link #CONNECTION_ALARM_INTERVAL_MS}. Only BasicDataSource pools are inspected.
     */
    private void checkActiveConnections(MappedStatement statement) {
        BasicDataSource basicDataSource = getBasicDataSource(statement);
        if (basicDataSource == null) {
            return;
        }
        int maxActive = basicDataSource.getMaxActive();
        if (maxActive <= 0) {
            // A non-positive maximum gives no meaningful ratio (division by zero or unbounded pool).
            return;
        }
        // Read the active count once so the check and the messages report the same numbers.
        int activeCount = basicDataSource.getNumActive();
        float ratio = (float) activeCount / maxActive;
        if (ratio < this.maxActiveConnectionRatio) {
            return;
        }

        long times = tooManyActiveConnectionAlarmTimes.incrementAndGet();
        logger.warn("数据库连接数过多,使用率已经超过了{}%, 当前活跃连接数{},允许最大活跃连接数{}, 检测连接池使用率超比例次数 {}", (this.maxActiveConnectionRatio * 100), activeCount, maxActive,
                times);

        if (tryAcquireConnectionAlarmSlot()) {
            // Simulated alarm.
            monitor.alarm("数据库连接数过多,使用率已经超过了" + (this.maxActiveConnectionRatio * 100)
                    + "%, 当前活跃连接数" + activeCount
                    + ",允许最大活跃连接数" + maxActive
                    + ", 检测到连接池使用超比例次数" + times);
        }
    }

    /**
     * Logs a failed statement together with pool statistics and raises an alarm.
     */
    private void reportSqlException(Throwable e, MappedStatement statement, String sqlId, BoundSql sql, Object parameterObject) {
        Throwable targetException = e;
        if (e instanceof InvocationTargetException) {
            // The target method's exception may arrive wrapped by the reflective call; log the real cause.
            targetException = ((InvocationTargetException) e).getTargetException();
        }
        if (targetException == null) {
            return;
        }
        String paramVal = toString(parameterObject);
        BasicDataSource basicDataSource = getBasicDataSource(statement);
        int maxActive;
        int active;
        if (basicDataSource == null) {
            // -1 marks "unknown" when the pool is not a BasicDataSource.
            maxActive = active = -1;
        } else {
            maxActive = basicDataSource.getMaxActive();
            active = basicDataSource.getNumActive();
        }
        logger.error("执行SQL异常,sqlId={}, sql={}, parameter={}, maxActive={}, current={}",
                sqlId,
                (sql == null ? "UNKNOWN SQL" : sql.getSql()),
                paramVal,
                maxActive,
                active);
        logger.error("SQL异常", targetException);

        // Simulated alarm.
        monitor.alarm(sqlId + " sqlException,错误详细信息请查看日志");
    }

    /**
     * Atomically claims the right to send a connection alarm. It returns true for the first call,
     * or when more than {@link #CONNECTION_ALARM_INTERVAL_MS} has passed since the last claimed
     * alarm. When several threads race, only the one whose compareAndSet succeeds wins.
     */
    private boolean tryAcquireConnectionAlarmSlot() {
        long now = System.currentTimeMillis();
        long last = tooManyActiveConnectionLatestAlarmMs.get();
        if (last != 0L && now - last <= CONNECTION_ALARM_INTERVAL_MS) {
            return false;
        }
        return tooManyActiveConnectionLatestAlarmMs.compareAndSet(last, now);
    }

    /**
     * Renders the statement parameter as JSON for logging; returns "null" for a null parameter.
     */
    private String toString(Object parameterObject) {
        if (parameterObject == null) {
            return "null";
        }
        return JSONUtil.bean2Json(parameterObject);
    }

    /**
     * Returns the statement's DataSource if it is a BasicDataSource, otherwise null.
     */
    private BasicDataSource getBasicDataSource(MappedStatement statement) {
        DataSource dataSource = statement.getConfiguration().getEnvironment().getDataSource();
        if (dataSource instanceof BasicDataSource) {
            return (BasicDataSource) dataSource;
        }
        return null;
    }

    public void setSqlExceptionEnabled(boolean sqlExceptionEnabled) {
        this.sqlExceptionEnabled = sqlExceptionEnabled;
    }

    public void setSlowSqlEnabled(boolean slowSqlEnabled) {
        this.slowSqlEnabled = slowSqlEnabled;
    }

    /**
     * Sets the slow-SQL threshold in milliseconds. Values below {@link #MIN_SLOW_SQL_TIMEOUT}
     * are ignored (with a warning) and the current value is kept.
     */
    public void setSlowSqlTimeout(long slowSqlTimeout) {
        if (slowSqlTimeout >= MIN_SLOW_SQL_TIMEOUT) {
            this.slowSqlTimeout = slowSqlTimeout;
        } else {
            logger.warn("slowSqlTimeout {} is below the minimum {}, keeping {}",
                    slowSqlTimeout, MIN_SLOW_SQL_TIMEOUT, this.slowSqlTimeout);
        }
    }

    public void setTooManyActiveConnectionEnabled(boolean tooManyActiveConnectionEnabled) {
        this.tooManyActiveConnectionEnabled = tooManyActiveConnectionEnabled;
    }

    /**
     * Sets the active-connection ratio threshold.
     *
     * @throws IllegalArgumentException if the ratio is greater than 1.0
     *     (values below {@link #MIN_MAX_ACTIVE_CONNECTION_RATIO} are ignored with a warning)
     */
    public void setMaxActiveConnectionRatio(float maxActiveConnectionRatio) {
        if (maxActiveConnectionRatio > 1.0f) {
            throw new IllegalArgumentException("maxActiveConnectionRatio must be between "
                    + MIN_MAX_ACTIVE_CONNECTION_RATIO + " and 1.0, but was " + maxActiveConnectionRatio);
        }
        if (maxActiveConnectionRatio >= MIN_MAX_ACTIVE_CONNECTION_RATIO) {
            this.maxActiveConnectionRatio = maxActiveConnectionRatio;
        } else {
            logger.warn("maxActiveConnectionRatio {} is below the minimum {}, keeping {}",
                    maxActiveConnectionRatio, MIN_MAX_ACTIVE_CONNECTION_RATIO, this.maxActiveConnectionRatio);
        }
    }

    /**
     * Wraps only Executor instances; every other MyBatis object is returned unchanged.
     */
    @Override
    public Object plugin(Object target) {
        if (target instanceof Executor) {
            return Plugin.wrap(target, this);
        }
        return target;
    }

    /**
     * Applies every non-empty configured property to the matching setter via BeanUtils.
     *
     * @throws RuntimeException if a property cannot be set (unknown name or wrong value type)
     */
    @Override
    public void setProperties(Properties properties) {
        logger.debug("set properties for {} mybatis plugin", this.getClass().getName());
        // Parse the configured property values and apply them to this interceptor.
        Set<String> keys = properties.stringPropertyNames();
        for (String key : keys) {
            String value = properties.getProperty(key);
            if (value != null && value.length() > 0) {
                try {
                    BeanUtils.setProperty(this, key, value);
                } catch (Throwable e) {
                    logger.error("属性值设置出错,请检查属性" + key + "的配置是否支持,或者属性的值类型不正确。");
                    throw new RuntimeException("configure property " + key + " error", e);
                }
            }
        }

        // The original message had no placeholder, so the class name was silently dropped.
        logger.debug("set properties end for {}", this.getClass().getName());
    }
}


坚持原创技术分享,您的支持将鼓励我继续创作!



分享

评论