jdbc - Spark SQL - Update a MySQL table using DataFrames and JDBC

I am trying to insert and update some data in MySQL using Spark DataFrames and a JDBC connection.

I have successfully inserted new data using SaveMode.Append. Is there a way to update the data that already exists in the MySQL table?

My code for inserting is:

myDataFrame.write.mode(SaveMode.Append).jdbc(JDBCurl,mySqlTable,connectionProperties)

If I change to SaveMode.Overwrite it deletes the whole table and creates a new one. I am looking for something like MySQL's "ON DUPLICATE KEY UPDATE".


It is not possible at the moment (Spark 1.6.0 / 2.2.0-SNAPSHOT). The Spark DataFrameWriter supports only four write modes:

  • SaveMode.Overwrite: overwrites the existing data,
  • SaveMode.Append: appends the data,
  • SaveMode.Ignore: ignores the operation (i.e. a no-op),
  • SaveMode.ErrorIfExists: the default option, throws an exception at runtime.

You can do this manually, for example using mapPartitions (given that the operation you want is an upsert, which should be idempotent and therefore easy to implement), by writing to a temporary table and executing the upsert manually, or by using triggers.

Keep in mind that in the general case there will be multiple concurrent transactions (one per partition), so you have to make sure there are no write conflicts (typically by using application-specific partitioning). In practice, batching the writes into a temporary table and performing the actual upsert inside the database is often the better option, as sketched below.
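Here is a rough sketch of that temporary-table variant. It is not from the original answer: the staging table name, the id/value columns and the upsertViaStagingTable helper are illustrative placeholders that you would adapt to your schema. Spark bulk-loads the new rows into a staging table, then a single statement lets MySQL resolve the conflicts:

import java.sql.DriverManager
import java.util.Properties

import org.apache.spark.sql.{DataFrame, SaveMode}

// Hypothetical helper: write the DataFrame to <targetTable>_staging, then let MySQL
// perform the upsert with a single INSERT ... SELECT ... ON DUPLICATE KEY UPDATE.
def upsertViaStagingTable(df: DataFrame,
                          jdbcUrl: String,
                          targetTable: String,
                          connectionProperties: Properties): Unit = {
  val stagingTable = s"${targetTable}_staging"

  // 1. Bulk-load the new data; Overwrite recreates the staging table on every run.
  df.write.mode(SaveMode.Overwrite).jdbc(jdbcUrl, stagingTable, connectionProperties)

  // 2. Resolve the upsert inside the database, so Spark's per-partition
  //    transactions never write to the target table concurrently.
  val conn = DriverManager.getConnection(
    jdbcUrl,
    connectionProperties.getProperty("user"),
    connectionProperties.getProperty("password"))
  try {
    val stmt = conn.createStatement()
    stmt.executeUpdate(
      s"""INSERT INTO $targetTable (id, value)
         |SELECT id, value FROM $stagingTable
         |ON DUPLICATE KEY UPDATE value = VALUES(value)""".stripMargin)
    stmt.close()
  } finally {
    conn.close()
  }
}

Only the staging write runs in parallel, so the write-conflict concern above never touches the target table.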

I just want to add that you can also use the JayDeBeApi package to work around this and update the data in your MySQL table: https://pypi.python.org/pypi/JayDeBeApi/

The JayDeBeApi module allows you to connect to databases from Python code using Java JDBC. It provides a Python DB-API v2.0 for the database.

We use the Anaconda distribution of Python, and the JayDeBeApi Python package comes standard with it.

Unfortunately, there is no SaveMode.Upsert mode in Spark for such a common case as upserting.

I would also like to provide some Java code for this case; just adjust it to your needs:


import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;
import java.sql.SQLIntegrityConstraintViolationException;
import java.sql.Statement;
import java.util.Arrays;
import java.util.Iterator;

import org.apache.spark.sql.Row;

myDF = myDF.repartition(20); // one connection per partition, see below

myDF.foreachPartition((Iterator<Row> t) -> {
    Connection conn = DriverManager.getConnection(
            Constants.DB_JDBC_CONN,
            Constants.DB_JDBC_USER,
            Constants.DB_JDBC_PASS);
    conn.setAutoCommit(true);
    Statement statement = conn.createStatement();

    final int batchSize = 100000;
    int i = 0;
    while (t.hasNext()) {
        Row row = t.next();
        try {
            // better than REPLACE INTO, less cycles
            statement.addBatch("INSERT INTO mytable VALUES ("
                    + "'" + row.getAs("_id") + "', "
                    + "'" + row.getStruct(1).get(0) + "'"
                    + ") ON DUPLICATE KEY UPDATE _id='" + row.getAs("_id") + "';");
            //conn.commit();

            // flush the batch every batchSize rows
            if (++i % batchSize == 0) {
                statement.executeBatch();
            }
        } catch (SQLIntegrityConstraintViolationException e) {
            // should not occur, nevertheless
            //conn.commit();
        } catch (SQLException e) {
            e.printStackTrace();
        } finally {
            //conn.commit();
            statement.executeBatch();
        }
    }
    int[] ret = statement.executeBatch();

    System.out.println("Ret val: " + Arrays.toString(ret));
    System.out.println("Update count: " + statement.getUpdateCount());
    conn.commit();

    statement.close();
    conn.close();
});



Another approach is to overwrite org.apache.spark.sql.execution.datasources.jdbc JdbcUtils.scala, changing its generated "insert into" SQL to "replace into":


import java.sql.{Connection, Driver, DriverManager, PreparedStatement, ResultSet, SQLException}

import scala.collection.JavaConverters._
import scala.util.control.NonFatal
import com.typesafe.scalalogging.Logger
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.jdbc.{DriverRegistry, DriverWrapper, JDBCOptions}
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row}

/**
 * Util functions for JDBC tables.
 */
object UpdateJdbcUtils {

  val logger = Logger(this.getClass)

  /**
   * Returns a factory for creating connections to the given JDBC URL.
   *
   * @param options - JDBC options that contains url, table and other information.
   */
  def createConnectionFactory(options: JDBCOptions): () => Connection = {
    val driverClass: String = options.driverClass
    () => {
      DriverRegistry.register(driverClass)
      val driver: Driver = DriverManager.getDrivers.asScala.collectFirst {
        case d: DriverWrapper if d.wrapped.getClass.getCanonicalName == driverClass => d
        case d if d.getClass.getCanonicalName == driverClass => d
      }.getOrElse {
        throw new IllegalStateException(
          s"Did not find registered driver with class $driverClass")
      }
      driver.connect(options.url, options.asConnectionProperties)
    }
  }

  /**
   * Returns a PreparedStatement that inserts a row into table via conn.
   */
  def insertStatement(conn: Connection, table: String, rddSchema: StructType, dialect: JdbcDialect)
      : PreparedStatement = {
    val columns = rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",")
    val placeholders = rddSchema.fields.map(_ => "?").mkString(",")
    val sql = s"REPLACE INTO $table ($columns) VALUES ($placeholders)"
    conn.prepareStatement(sql)
  }

  /**
   * Retrieve standard jdbc types.
   *
   * @param dt The datatype (e.g. [[org.apache.spark.sql.types.StringType]])
   * @return The default JdbcType for this DataType
   */
  def getCommonJDBCType(dt: DataType): Option[JdbcType] = {
    dt match {
      case IntegerType => Option(JdbcType("INTEGER", java.sql.Types.INTEGER))
      case LongType => Option(JdbcType("BIGINT", java.sql.Types.BIGINT))
      case DoubleType => Option(JdbcType("DOUBLE PRECISION", java.sql.Types.DOUBLE))
      case FloatType => Option(JdbcType("REAL", java.sql.Types.FLOAT))
      case ShortType => Option(JdbcType("INTEGER", java.sql.Types.SMALLINT))
      case ByteType => Option(JdbcType("BYTE", java.sql.Types.TINYINT))
      case BooleanType => Option(JdbcType("BIT(1)", java.sql.Types.BIT))
      case StringType => Option(JdbcType("TEXT", java.sql.Types.CLOB))
      case BinaryType => Option(JdbcType("BLOB", java.sql.Types.BLOB))
      case TimestampType => Option(JdbcType("TIMESTAMP", java.sql.Types.TIMESTAMP))
      case DateType => Option(JdbcType("DATE", java.sql.Types.DATE))
      case t: DecimalType => Option(
        JdbcType(s"DECIMAL(${t.precision},${t.scale})", java.sql.Types.DECIMAL))
      case _ => None
    }
  }

  private def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = {
    dialect.getJDBCType(dt).orElse(getCommonJDBCType(dt)).getOrElse(
      throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.simpleString}"))
  }

  // A `JDBCValueGetter` is responsible for getting a value from `ResultSet` into a field
  // for `MutableRow`. The last argument `Int` means the index for the value to be set in
  // the row and also used for the value in `ResultSet`.
  private type JDBCValueGetter = (ResultSet, InternalRow, Int) => Unit

  // A `JDBCValueSetter` is responsible for setting a value from `Row` into a field for
  // `PreparedStatement`. The last argument `Int` means the index for the value to be set
  // in the SQL statement and also used for the value in `Row`.
  private type JDBCValueSetter = (PreparedStatement, Row, Int) => Unit

  /**
   * Saves a partition of a DataFrame to the JDBC database. This is done in
   * a single database transaction (unless isolation level is "NONE")
   * in order to avoid repeatedly inserting data as much as possible.
   *
   * It is still theoretically possible for rows in a DataFrame to be
   * inserted into the database more than once if a stage somehow fails after
   * the commit occurs but before the stage can return successfully.
   *
   * This is not a closure inside saveTable() because apparently cosmetic
   * implementation changes elsewhere might easily render such a closure
   * non-Serializable. Instead, we explicitly close over all variables that
   * are used.
   */
  def savePartition(
      getConnection: () => Connection,
      table: String,
      iterator: Iterator[Row],
      rddSchema: StructType,
      nullTypes: Array[Int],
      batchSize: Int,
      dialect: JdbcDialect,
      isolationLevel: Int): Iterator[Byte] = {
    val conn = getConnection()
    var committed = false

    var finalIsolationLevel = Connection.TRANSACTION_NONE
    if (isolationLevel != Connection.TRANSACTION_NONE) {
      try {
        val metadata = conn.getMetaData
        if (metadata.supportsTransactions()) {
          // Update to at least use the default isolation, if any transaction level
          // has been chosen and transactions are supported
          val defaultIsolation = metadata.getDefaultTransactionIsolation
          finalIsolationLevel = defaultIsolation
          if (metadata.supportsTransactionIsolationLevel(isolationLevel)) {
            // Finally update to actually requested level if possible
            finalIsolationLevel = isolationLevel
          } else {
            logger.warn(s"Requested isolation level $isolationLevel is not supported; " +
              s"falling back to default isolation level $defaultIsolation")
          }
        } else {
          logger.warn(s"Requested isolation level $isolationLevel, but transactions are unsupported")
        }
      } catch {
        case NonFatal(e) => logger.warn("Exception while detecting transaction support", e)
      }
    }
    val supportsTransactions = finalIsolationLevel != Connection.TRANSACTION_NONE

    try {
      if (supportsTransactions) {
        conn.setAutoCommit(false) // Everything in the same db transaction.
        conn.setTransactionIsolation(finalIsolationLevel)
      }
      val stmt = insertStatement(conn, table, rddSchema, dialect)
      val setters: Array[JDBCValueSetter] = rddSchema.fields.map(_.dataType)
        .map(makeSetter(conn, dialect, _))
      val numFields = rddSchema.fields.length

      try {
        var rowCount = 0
        while (iterator.hasNext) {
          val row = iterator.next()
          var i = 0
          while (i < numFields) {
            if (row.isNullAt(i)) {
              stmt.setNull(i + 1, nullTypes(i))
            } else {
              setters(i).apply(stmt, row, i)
            }
            i = i + 1
          }
          stmt.addBatch()
          rowCount += 1
          if (rowCount % batchSize == 0) {
            stmt.executeBatch()
            rowCount = 0
          }
        }
        if (rowCount > 0) {
          stmt.executeBatch()
        }
      } finally {
        stmt.close()
      }
      if (supportsTransactions) {
        conn.commit()
      }
      committed = true
      Iterator.empty
    } catch {
      case e: SQLException =>
        val cause = e.getNextException
        if (cause != null && e.getCause != cause) {
          if (e.getCause == null) {
            e.initCause(cause)
          } else {
            e.addSuppressed(cause)
          }
        }
        throw e
    } finally {
      if (!committed) {
        // The stage must fail. We got here through an exception path, so
        // let the exception through unless rollback() or close() want to
        // tell the user about another problem.
        if (supportsTransactions) {
          conn.rollback()
        }
        conn.close()
      } else {
        // The stage must succeed. We cannot propagate any exception close() might throw.
        try {
          conn.close()
        } catch {
          case e: Exception => logger.warn("Transaction succeeded, but closing failed", e)
        }
      }
    }
  }

  /**
   * Saves the RDD to the database in a single transaction.
   */
  def saveTable(
      df: DataFrame,
      url: String,
      table: String,
      options: JDBCOptions) {
    val dialect = JdbcDialects.get(url)
    val nullTypes: Array[Int] = df.schema.fields.map { field =>
      getJdbcType(field.dataType, dialect).jdbcNullType
    }

    val rddSchema = df.schema
    val getConnection: () => Connection = createConnectionFactory(options)
    val batchSize = options.batchSize
    val isolationLevel = options.isolationLevel
    df.foreachPartition(iterator => savePartition(
      getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect, isolationLevel)
    )
  }

  private def makeSetter(
      conn: Connection,
      dialect: JdbcDialect,
      dataType: DataType): JDBCValueSetter = dataType match {
    case IntegerType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setInt(pos + 1, row.getInt(pos))

    case LongType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setLong(pos + 1, row.getLong(pos))

    case DoubleType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setDouble(pos + 1, row.getDouble(pos))

    case FloatType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setFloat(pos + 1, row.getFloat(pos))

    case ShortType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setInt(pos + 1, row.getShort(pos))

    case ByteType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setInt(pos + 1, row.getByte(pos))

    case BooleanType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setBoolean(pos + 1, row.getBoolean(pos))

    case StringType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setString(pos + 1, row.getString(pos))

    case BinaryType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setBytes(pos + 1, row.getAs[Array[Byte]](pos))

    case TimestampType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos))

    case DateType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setDate(pos + 1, row.getAs[java.sql.Date](pos))

    case t: DecimalType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setBigDecimal(pos + 1, row.getDecimal(pos))

    case ArrayType(et, _) =>
      // remove type length parameters from end of type name
      val typeName = getJdbcType(et, dialect).databaseTypeDefinition
        .toLowerCase.split("\\(")(0)
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        val array = conn.createArrayOf(
          typeName,
          row.getSeq[AnyRef](pos).toArray)
        stmt.setArray(pos + 1, array)

    case _ =>
      (_: PreparedStatement, _: Row, pos: Int) =>
        throw new IllegalArgumentException(
          s"Can't translate non-null value for field $pos")
  }
}



Usage:


val url = s"jdbc:mysql://$host/$database?useUnicode=true&characterEncoding=UTF-8"



val parameters: Map[String, String] = Map(


"url" -> url,


"dbtable" -> table,


"driver" ->"com.mysql.jdbc.Driver",


"numPartitions" -> numPartitions.toString,


"user" -> user,


"password" -> password


)


val options = new JDBCOptions(parameters)



for (d <- data) {


 UpdateJdbcUtils.saveTable(d, url, table, options)


}



PS: watch out for deadlocks. Do not update the data frequently; use this only for re-runs in emergency cases. I think that is why Spark does not support this officially.
