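This article walks briefly through some of the key source code behind how Spark reads from a database.
How Spark reads database data
Like other data-mapping frameworks (Hibernate, MyBatis, and so on), Spark cannot avoid a JDBC connection when it reads from a database; JDBC is, after all, the official channel through which code "talks" to a database. To read database data quickly, Spark has to solve problems including, but not limited to:
- distributed reads
- mapping raw data to an RDD/DataFrame
This short article therefore focuses on the source code behind these two aspects.
For the Spark database API itself, see this article: Spark JDBC系列--取数的四种方式 (Spark JDBC series: four ways to fetch data).
Source code walkthrough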
1. JDBC API common entry point
Entry-point source:
org.apache.spark.sql.DataFrameReader
...
private def jdbc(
    url: String,
    table: String,
    parts: Array[Partition],
    connectionProperties: Properties): DataFrame = {
  val props = new Properties()
  extraOptions.foreach { case (key, value) =>
    props.put(key, value)
  }
  // connectionProperties should override settings in extraOptions
  props.putAll(connectionProperties)
  // Key point: every public jdbc(...) overload ends up building a JDBCRelation
  val relation = JDBCRelation(url, table, parts, props)(sparkSession)
  // Creates the logical partitions; the actual read is triggered only once an action runs
  sparkSession.baseRelationToDataFrame(relation)
}
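The merge order above decides which setting wins when the same key appears in both places. A minimal sketch of that behaviour in plain Scala follows; the object name `MergePropsSketch`, the `merge` helper, and the sample keys are illustrative assumptions, not part of Spark.

```scala
import java.util.Properties

object MergePropsSketch {
  // Hypothetical helper mirroring the merge order of the private jdbc(...) method:
  // reader options are copied first, then connectionProperties overwrite duplicates.
  def merge(extraOptions: Map[String, String], connectionProperties: Properties): Properties = {
    val props = new Properties()
    extraOptions.foreach { case (key, value) => props.setProperty(key, value) }
    val names = connectionProperties.stringPropertyNames().iterator()
    while (names.hasNext) {
      val key = names.next()
      props.setProperty(key, connectionProperties.getProperty(key))
    }
    props
  }

  def main(args: Array[String]): Unit = {
    val connectionProperties = new Properties()
    connectionProperties.setProperty("user", "admin")
    val merged = merge(Map("user" -> "reader", "fetchsize" -> "1000"), connectionProperties)
    println(merged.getProperty("user"))      // prints admin (connectionProperties wins)
    println(merged.getProperty("fetchsize")) // prints 1000 (kept from extraOptions)
  }
}
```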
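Reading the source shows that although the four data-fetching APIs take slightly different parameters, they are all eventually converted into an Array[Partition], that is, an array of partition conditions.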
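As a rough illustration of that conversion, the sketch below turns a list of user-supplied predicates into indexed partitions, one partition per predicate. The case class here is a simplified stand-in for Spark's JDBCPartition, and `fromPredicates` is a hypothetical helper name; Spark's own code may differ in detail.

```scala
object PredicatePartitionsSketch {
  // Simplified stand-in for Spark's JDBCPartition(whereClause, idx).
  final case class JDBCPartition(whereClause: String, idx: Int)

  // Hypothetical helper: each predicate becomes the where clause of one partition,
  // and its position in the array becomes the partition index.
  def fromPredicates(predicates: Array[String]): Array[JDBCPartition] =
    predicates.zipWithIndex.map { case (predicate, i) => JDBCPartition(predicate, i) }

  def main(args: Array[String]): Unit = {
    val parts = fromPredicates(Array("date < '2017-11-01'", "date >= '2017-11-01'"))
    parts.foreach(println)
  }
}
```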
2. How the column-based fetch API partitions data
This section looks at the partitioning logic of the API that takes a long-typed column. The source first:
def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = {
  if (partitioning == null || partitioning.numPartitions <= 1 ||
    partitioning.lowerBound == partitioning.upperBound) {
    // Single-partition mode takes this branch
    return Array[Partition](JDBCPartition(null, 0))
  }
  // Validity checks
  val lowerBound = partitioning.lowerBound
  val upperBound = partitioning.upperBound
  ....
  // Adjust the partition count: never more partitions than distinct values in the range
  val numPartitions =
    if ((upperBound - lowerBound) >= partitioning.numPartitions) {
      partitioning.numPartitions
    } else {
      upperBound - lowerBound
    }
  // Compute the stride
  val stride: Long = upperBound / numPartitions - lowerBound / numPartitions
  val column = partitioning.column
  var i: Int = 0
  var currentValue: Long = lowerBound
  var ans = new ArrayBuffer[Partition]()
  // Starting from the lower bound, advance by the stride to find each boundary,
  // then assemble the where condition for each partition
  while (i < numPartitions) {
    // Note: the first and last partitions are bounded on one side only,
    // e.g. JDBCPartition(id >= 901,9)
    val lBound = if (i != 0) s"$column >= $currentValue" else null
    currentValue += stride
    val uBound = if (i != numPartitions - 1) s"$column < $currentValue" else null
    val whereClause =
      if (uBound == null) {
        lBound
      } else if (lBound == null) {
        s"$uBound or $column is null"
      } else {
        s"$lBound AND $uBound"
      }
    ans += JDBCPartition(whereClause, i)
    i = i + 1
  }
  ans.toArray
}
A test input and the resulting partitions:
Input:
lowerBound=1, upperBound=1000, numPartitions=10
Resulting partition array:
JDBCPartition(id < 101 or id is null,0),
JDBCPartition(id >= 101 AND id < 201,1),
JDBCPartition(id >= 201 AND id < 301,2),
JDBCPartition(id >= 301 AND id < 401,3),
JDBCPartition(id >= 401 AND id < 501,4),
JDBCPartition(id >= 501 AND id < 601,5),
JDBCPartition(id >= 601 AND id < 701,6),
JDBCPartition(id >= 701 AND id < 801,7),
JDBCPartition(id >= 801 AND id < 901,8),
JDBCPartition(id >= 901,9)
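To check the arithmetic behind this result without a Spark cluster, the partitioning logic can be re-implemented in plain Scala. This is a sketch under assumptions: the case class is a simplified stand-in for Spark's JDBCPartition, the function takes its inputs as plain parameters instead of a JDBCPartitioningInfo, and the `require` line is my guess at the elided validity check.

```scala
import scala.collection.mutable.ArrayBuffer

object ColumnPartitionSketch {
  // Simplified stand-in for Spark's JDBCPartition(whereClause, idx).
  final case class JDBCPartition(whereClause: String, idx: Int)

  def columnPartition(column: String, lowerBound: Long, upperBound: Long,
                      requested: Int): Array[JDBCPartition] =
    if (requested <= 1 || lowerBound == upperBound) {
      Array(JDBCPartition(null, 0)) // single-partition mode: no where clause
    } else {
      // Assumption: stands in for the validity checks elided in the excerpt.
      require(lowerBound < upperBound, "lowerBound must not exceed upperBound")
      val numPartitions: Long =
        if (upperBound - lowerBound >= requested) requested else upperBound - lowerBound
      val stride = upperBound / numPartitions - lowerBound / numPartitions
      val ans = ArrayBuffer.empty[JDBCPartition]
      var currentValue = lowerBound
      var i = 0
      while (i < numPartitions) {
        // First partition has no lower bound, last partition has no upper bound.
        val lBound = if (i != 0) s"$column >= $currentValue" else null
        currentValue += stride
        val uBound = if (i != numPartitions - 1) s"$column < $currentValue" else null
        val whereClause =
          if (uBound == null) lBound
          else if (lBound == null) s"$uBound or $column is null"
          else s"$lBound AND $uBound"
        ans += JDBCPartition(whereClause, i)
        i += 1
      }
      ans.toArray
    }

  def main(args: Array[String]): Unit =
    columnPartition("id", 1L, 1000L, 10).foreach(println)
}
```

With lowerBound=1 and upperBound=1000 the stride is 1000/10 - 1/10 = 100 under integer division, which is why the boundaries fall on 101, 201, and so on.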
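This API is easy to misuse. If you pass the minimum and maximum of just one ID range (rather than the true minimum and maximum of the whole table), Spark still reads the entire table, and the data is skewed as well. The cause lies in how the where conditions of the first and last partitions are built: the first partition has no lower bound and the last has no upper bound, so together they pick up every row outside the range. If you need to restrict the range or add further conditions, use the API that accepts custom partition conditions instead.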
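For example, applying the logic above to lowerBound=100, upperBound=200, numPartitions=4 gives a first partition of `id < 125 or id is null` and a last partition of `id >= 175`, both of which reach far outside the intended range. One possible way to build conditions that are bounded on both sides is sketched below; `boundedPredicates` is a hypothetical helper written for this article, and it treats both ends of the range as inclusive.

```scala
object BoundedPredicatesSketch {
  // Hypothetical helper: split the inclusive range [lower, upper] into at most n
  // contiguous chunks, each with an explicit lower AND upper bound.
  def boundedPredicates(column: String, lower: Long, upper: Long, n: Int): Array[String] = {
    require(n > 0 && lower <= upper, "need n > 0 and lower <= upper")
    val total = upper - lower + 1
    val parts = math.min(n.toLong, total).toInt
    val base = total / parts   // minimum chunk size
    val extra = total % parts  // the first `extra` chunks hold one more value
    (0 until parts).map { i =>
      val start = lower + i * base + math.min(i.toLong, extra)
      val end = start + base + (if (i < extra) 1 else 0)
      s"$column >= $start AND $column < $end"
    }.toArray
  }

  def main(args: Array[String]): Unit =
    boundedPredicates("id", 100L, 199L, 4).foreach(println)
}
```

The resulting array can then be passed to the `jdbc(url, table, predicates, connectionProperties)` overload of DataFrameReader, so that each element becomes the where clause of one partition and no partition reads rows outside the range.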
3. Mapping the result data
Functions:
org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation
// Resolve the DataFrame schema, i.e. map the database column types to Spark data types
override val schema: StructType = JDBCRDD.resolveTable(url, table, properties)
// The implementation
org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD
def resolveTable(url: String, table: String, properties: Properties): StructType = {
  // Pick the dialect to use based on the url
  val dialect = JdbcDialects.get(url)
  val ncols = rsmd.getColumnCount
  val fields = new Array[StructField](ncols)
  var i = 0
  ....
  while (i < ncols) {
    val columnName = rsmd.getColumnLabel(i + 1)
    val dataType = rsmd.getColumnType(i + 1)
    val typeName = rsmd.getColumnTypeName(i + 1)
    val fieldSize = rsmd.getPrecision(i + 1)
    val fieldScale = rsmd.getScale(i + 1)
    ....
    // Map according to the dialect's rules; fall back to the default mapping when the dialect has none
    val columnType = dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse(
      getCatalystType(dataType, fieldSize, fieldScale, isSigned))
    fields(i) = StructField(columnName, columnType, nullable, metadata.build())
    i = i + 1
  }
  return new StructType(fields)
}
Some examples from the default field mapping:
val answer = sqlType match {
  ....
  case java.sql.Types.BLOB => BinaryType
  case java.sql.Types.BOOLEAN => BooleanType
  case java.sql.Types.CHAR => StringType
  case java.sql.Types.CLOB => StringType
  case java.sql.Types.DATALINK => null
  case java.sql.Types.DATE => DateType
  case java.sql.Types.DECIMAL
    if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale)
  case java.sql.Types.DECIMAL => DecimalType.SYSTEM_DEFAULT
  case java.sql.Types.DISTINCT => null
  case java.sql.Types.DOUBLE => DoubleType
  case java.sql.Types.FLOAT => FloatType
  ....
}
The MySQL dialect serves as an example here.
All dialect implementations live in the package org.apache.spark.sql.jdbc.*; see the source for the others.
MySQL dialect:
private case object MySQLDialect extends JdbcDialect {
  override def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql")
  override def getCatalystType(
      sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
    // Key implementation
    if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) {
      // This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as
      // byte arrays instead of longs.
      md.putLong("binarylong", 1)
      Option(LongType)
    } else if (sqlType == Types.BIT && typeName.equals("TINYINT")) {
      Option(BooleanType)
    } else None
  }
  ....
}
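As the source shows, the MySQL dialect only special-cases the bit and tinyint types; every other type falls back to Spark's default mapping. When reading data, therefore, check whether Spark's dialect mapping affects the data you already have, so that values are not distorted.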
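The "dialect first, default second" lookup can be sketched in plain Scala as follows. To keep it runnable without Spark, type names are plain strings rather than Catalyst DataType objects, only a handful of default mappings are included, and the names `mysqlDialectType`, `defaultType`, and `resolve` are illustrative assumptions.

```scala
import java.sql.Types

object TypeMappingSketch {
  // Mirrors the two special cases of the MySQL dialect excerpt above.
  def mysqlDialectType(sqlType: Int, typeName: String, size: Int): Option[String] =
    if (sqlType == Types.VARBINARY && typeName == "BIT" && size != 1) Some("LongType")
    else if (sqlType == Types.BIT && typeName == "TINYINT") Some("BooleanType")
    else None

  // A small subset of the default mapping; unknown types yield None in this sketch.
  def defaultType(sqlType: Int): Option[String] = sqlType match {
    case Types.BLOB    => Some("BinaryType")
    case Types.BOOLEAN => Some("BooleanType")
    case Types.CHAR    => Some("StringType")
    case Types.DATE    => Some("DateType")
    case Types.DOUBLE  => Some("DoubleType")
    case Types.FLOAT   => Some("FloatType")
    case _             => None
  }

  // The dialect is consulted first; the default mapping is used only when it returns None.
  def resolve(sqlType: Int, typeName: String, size: Int): Option[String] =
    mysqlDialectType(sqlType, typeName, size).orElse(defaultType(sqlType))

  def main(args: Array[String]): Unit = {
    println(resolve(Types.BIT, "TINYINT", 1))   // prints Some(BooleanType)
    println(resolve(Types.VARBINARY, "BIT", 8)) // prints Some(LongType)
    println(resolve(Types.DATE, "DATE", 10))    // prints Some(DateType)
  }
}
```

A column that hits the second branch is read as a boolean, so a value other than 0 or 1 stored in it would presumably not survive the read unchanged; this is the kind of distortion worth checking for.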
At this point the JDBCRelation object has been fully constructed.
4. RDD construction and logical partition generation
Based on the JDBCRelation built earlier, sparkSession adds the task to the logical execution plan. When an action is encountered, the logical plan is converted into a physical execution plan.
org.apache.spark.sql.SparkSession
// Builds the logical execution plan; details omitted, as I have not studied this part of the source closely
def baseRelationToDataFrame(baseRelation: BaseRelation): DataFrame = {
  Dataset.ofRows(self, LogicalRelation(baseRelation))
}
org.apache.spark.sql.execution.datasources.DataSourceStrategy
// Physical execution plan
object DataSourceStrategy extends Strategy with Logging {
  def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match {
    case PhysicalOperation.....
    // JDBCRelation extends PrunedFilteredScan, so it matches this case, which calls buildScan
    case PhysicalOperation(projects, filters, l @ LogicalRelation(t: PrunedFilteredScan, _, _)) =>
      pruneFilterProject(
        l,
        projects,
        filters,
        (a, f) => toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray, f))) :: Nil
    case PhysicalOperation.....
  }
}
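When JDBCRelation's buildScan method runs, it calls JDBCRDD's scanTable method to create a new RDD. Any filter conditions added before the computation are merged into the where clause of the JDBC query, joined with AND: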
private[jdbc] class JDBCRDD(
    sc: SparkContext,
    getConnection: () => Connection,
    schema: StructType,
    fqTable: String,
    columns: Array[String],
    filters: Array[Filter],
    partitions: Array[Partition],
    url: String,
    properties: Properties)
  extends RDD[InternalRow](sc, Nil) {
  override def getPartitions: Array[Partition] = partitions
  .....
  private def getWhereClause(part: JDBCPartition): String = {
    if (part.whereClause != null && filterWhereClause.length > 0) {
      "WHERE " + s"($filterWhereClause)" + " AND " + s"(${part.whereClause})"
    } else if (part.whereClause != null) {
      "WHERE " + part.whereClause
    } else if (filterWhereClause.length > 0) {
      "WHERE " + filterWhereClause
    } else {
      ""
    }
  }
  // compute runs when an action fires: it executes the SQL statement and maps
  // the result rows according to the schema resolved earlier
  override def compute(thePart: Partition, context: TaskContext): Iterator[InternalRow] =
    new Iterator[InternalRow] {
      ....
      // Implementation details omitted; mainly the JDBC query and the data type mapping
    }
}
Example of using filter conditions:
val url = "jdbc:mysql://mysqlHost:3306/database"
val tableName = "table"
val columnName = "id"
val lowerBound = getMinId()
val upperBound = getMaxId()
val numPartitions = 200
// Set the connection user and password
val prop = new java.util.Properties
prop.setProperty("user","username")
prop.setProperty("password","pwd")
// Filter the MySQL data; numPartitions must be passed between upperBound and prop
val jdbcDF = sqlContext.read
  .jdbc(url, tableName, columnName, lowerBound, upperBound, numPartitions, prop)
  .where("date='2017-11-30'")
  .filter("name is not null")
where and filter are equivalent: both conditions take effect in the where clause of the generated query, and multiple conditions are joined with AND.
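To make the combination concrete, here is a plain-Scala sketch of how the filter conditions and one partition condition might end up in a single where clause. It mirrors getWhereClause above but takes both inputs as parameters; the compiled filter strings are an assumption about what Spark generates for these two conditions, not output captured from a real run.

```scala
object WhereClauseSketch {
  // Same branching as JDBCRDD.getWhereClause in the excerpt above,
  // with the filter clause passed in explicitly.
  def getWhereClause(filterWhereClause: String, partitionWhereClause: String): String =
    if (partitionWhereClause != null && filterWhereClause.nonEmpty)
      s"WHERE ($filterWhereClause) AND ($partitionWhereClause)"
    else if (partitionWhereClause != null) s"WHERE $partitionWhereClause"
    else if (filterWhereClause.nonEmpty) s"WHERE $filterWhereClause"
    else ""

  def main(args: Array[String]): Unit = {
    // Assumed compiled form of .where("date='2017-11-30'").filter("name is not null")
    val filterWhereClause = Seq("date = '2017-11-30'", "name IS NOT NULL").mkString(" AND ")
    println(getWhereClause(filterWhereClause, "id >= 101 AND id < 201"))
    println(getWhereClause(filterWhereClause, null))
  }
}
```

Each partition therefore issues its own query: the filter part is identical across partitions, and only the partition condition changes.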
Conclusion
When reading database data, you can step into the corresponding source code and debug it to see what actually happens.