Пишем коннектор: read path, write path, pushdown
Предыдущий урок показал архитектуру DSv2 изнутри. Теперь переходим к практике: напишем полноценный коннектор для HTTP API с поддержкой read/write и несколькими уровнями pushdown. Это реалистичный сценарий — именно так устроены коннекторы к REST-сервисам, NoSQL-базам и custom data stores.
Наш коннектор будет читать и писать данные из воображаемого HTTP-сервиса, который возвращает JSON. Схема таблицы: id BIGINT, name STRING, amount DOUBLE, category STRING, ts TIMESTAMP. Сервис поддерживает фильтрацию по параметрам запроса, что позволит нам реализовать pushdown.
Точка входа: DataSourceRegister и TableProvider
Сначала регистрируем коннектор под именем http-api:
// В resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
// com.example.connector.HttpApiDataSource
package com.example.connector
import org.apache.spark.sql.connector.catalog.{Table, TableProvider}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.types.{StructType, LongType, StringType, DoubleType, TimestampType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import java.util
class HttpApiDataSource extends TableProvider {
// Схема фиксирована нашим API
override def inferSchema(options: CaseInsensitiveStringMap): StructType =
HttpApiDataSource.SCHEMA
override def inferPartitioning(options: CaseInsensitiveStringMap): Array[Transform] =
Array.empty // Партиционирования нет -- API не поддерживает
override def getTable(
schema: StructType,
partitioning: Array[Transform],
properties: util.Map[String, String]): Table = {
// options содержат: baseUrl, numPartitions, apiKey, timeout
new HttpApiTable(schema, new CaseInsensitiveStringMap(properties))
}
override def supportsExternalMetadata(): Boolean = true
// true: Spark не будет вызывать inferSchema самостоятельно
// Пользователь может указать schema явно через .schema(...)
}
object HttpApiDataSource {
val SCHEMA: StructType = StructType(Seq(
StructField("id", LongType, nullable = false),
StructField("name", StringType, nullable = true),
StructField("amount", DoubleType, nullable = true),
StructField("category", StringType, nullable = true),
StructField("ts", TimestampType, nullable = true)
))
}
Table: объявляем capabilities
package com.example.connector
import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability}
import org.apache.spark.sql.connector.read.ScanBuilder
import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import java.util
class HttpApiTable(
override val schema: StructType,
options: CaseInsensitiveStringMap
) extends Table with SupportsRead with SupportsWrite {
override def name(): String = s"http-api:${options.get("baseUrl")}"
override def partitioning(): Array[Transform] = Array.empty
override def properties(): util.Map[String, String] = options.asCaseSensitiveMap()
override def capabilities(): util.Set[TableCapability] =
util.EnumSet.of(
TableCapability.BATCH_READ,
TableCapability.BATCH_WRITE,
TableCapability.TRUNCATE
)
override def newScanBuilder(scanOptions: CaseInsensitiveStringMap): ScanBuilder = {
// Мержим опции таблицы с опциями скана
val merged = new CaseInsensitiveStringMap(
(options.asCaseSensitiveMap().asScala ++ scanOptions.asCaseSensitiveMap().asScala).asJava
)
new HttpApiScanBuilder(schema, merged)
}
override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder =
new HttpApiWriteBuilder(options, info.schema())
}
ScanBuilder: планирование с pushdown
Это сердце read path. Реализуем два уровня pushdown: фильтры и column pruning.
package com.example.connector
import org.apache.spark.sql.connector.read._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import scala.collection.mutable
class HttpApiScanBuilder(
private var schema: StructType,
options: CaseInsensitiveStringMap
) extends ScanBuilder
with SupportsPushDownRequiredColumns
with SupportsPushDownFilters {
// Фильтры, которые коннектор РЕАЛЬНО применит к HTTP-запросу
private val pushedFiltersBuffer = mutable.ArrayBuffer[Filter]()
// Фильтры, которые коннектор НЕ УМЕЕТ применить (Spark применит их сам)
private val postScanFiltersBuffer = mutable.ArrayBuffer[Filter]()
// 1. Column pruning -- вызывается Catalyst до build()
override def pruneColumns(requiredSchema: StructType): Unit = {
// Запоминаем только нужные поля
// Это уменьшит JSON, который мы парсим на executor
schema = requiredSchema
}
// 2. Filter pushdown -- вызывается Catalyst до build()
override def pushFilters(filters: Array[Filter]): Array[Filter] = {
for (filter <- filters) {
if (canPushDown(filter)) {
pushedFiltersBuffer += filter
} else {
postScanFiltersBuffer += filter
}
}
// Возвращаем фильтры, которые Spark должен применить ПОСЛЕ чтения
postScanFiltersBuffer.toArray
}
override def pushedFilters(): Array[Filter] = pushedFiltersBuffer.toArray
// Проверяем, какие фильтры поддерживает наш HTTP API
private def canPushDown(filter: Filter): Boolean = filter match {
case EqualTo(col, _) if isSupportedColumn(col) => true
case GreaterThan(col, _) if isSupportedColumn(col) => true
case GreaterThanOrEqual(col, _) if isSupportedColumn(col) => true
case LessThan(col, _) if isSupportedColumn(col) => true
case LessThanOrEqual(col, _) if isSupportedColumn(col) => true
case In(col, _) if isSupportedColumn(col) => true
case IsNotNull(col) if isSupportedColumn(col) => true
case _ => false
// NOT, AND, OR -- не поддерживаем: наш API принимает только flat query params
}
// Наш API позволяет фильтровать только по индексированным полям
private val indexedColumns = Set("category", "amount", "id")
private def isSupportedColumn(col: String): Boolean =
indexedColumns.contains(col.toLowerCase)
override def build(): Scan =
new HttpApiScan(
schema,
options,
pushedFiltersBuffer.toArray
)
}
Scan и Batch: разбиение на партиции
package com.example.connector
import org.apache.spark.sql.connector.read._
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
// Scan -- сериализуемый снэпшот параметров чтения
class HttpApiScan(
val readSchema: StructType,
val options: CaseInsensitiveStringMap,
val pushedFilters: Array[Filter]
) extends Scan with Serializable {
override def toBatch(): Batch = new HttpApiBatch(this)
override def description(): String = {
val filterDesc = pushedFilters.map(_.toString).mkString(", ")
s"HttpApiScan[schema=${readSchema.fieldNames.mkString(",")}, filters=[$filterDesc]]"
}
}
class HttpApiBatch(scan: HttpApiScan) extends Batch {
override def planInputPartitions(): Array[InputPartition] = {
val baseUrl = scan.options.get("baseUrl")
val numPartitions = scan.options.getOrDefault("numPartitions", "4").toInt
// Делим данные на страницы (page-based partitioning)
// Каждая партиция = один HTTP-запрос с параметрами page и pageSize
val pageSize = scan.options.getOrDefault("pageSize", "1000").toInt
// Получаем общее количество записей (HEAD-запрос)
val totalCount = HttpApiClient.getCount(baseUrl, scan.pushedFilters)
val actualNumPartitions = math.min(numPartitions,
math.ceil(totalCount.toDouble / pageSize).toInt.max(1))
val recordsPerPartition = math.ceil(totalCount.toDouble / actualNumPartitions).toInt
(0 until actualNumPartitions).map { i =>
new HttpApiInputPartition(
baseUrl = baseUrl,
offset = i * recordsPerPartition,
limit = recordsPerPartition,
filters = scan.pushedFilters,
schema = scan.readSchema,
apiKey = scan.options.get("apiKey")
).asInstanceOf[InputPartition]
}.toArray
}
override def createReaderFactory(): PartitionReaderFactory =
new HttpApiReaderFactory()
}
// InputPartition -- сериализуемое описание одного HTTP-запроса
// Держим ТОЛЬКО примитивы + сериализуемые типы
case class HttpApiInputPartition(
baseUrl: String,
offset: Long,
limit: Int,
filters: Array[Filter], // Filter -- сериализуемый через Java Serialization
schema: StructType, // StructType -- сериализуемый
apiKey: String
) extends InputPartition
planInputPartitions() вызывается на driver, и вы видите соблазн сделать здесь “умное” разбиение: сходить в БД, узнать реальный размер каждой партиции, оптимально распределить нагрузку. Это допустимо, но опасно: если этот метод занимает 30 секунд, они вычитаются из job timeout. Для Iceberg с миллионом файлов это реальная проблема — держите planInputPartitions() как можно быстрее, либо используйте параллельный планировщик с ExecutorService.
PartitionReaderFactory и PartitionReader
package com.example.connector
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.connector.read.{PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import java.io.IOException
import scala.collection.mutable
class HttpApiReaderFactory extends PartitionReaderFactory with Serializable {
override def createReader(partition: InputPartition): PartitionReader[InternalRow] =
new HttpApiPartitionReader(partition.asInstanceOf[HttpApiInputPartition])
}
class HttpApiPartitionReader(partition: HttpApiInputPartition)
extends PartitionReader[InternalRow] {
// HTTP-клиент -- создаётся лениво на executor, не сериализуется
private lazy val client = new HttpApiClient(partition.baseUrl, partition.apiKey)
// Буфер для текущей страницы (чтобы не держать весь ответ в памяти)
private val pageBuffer = mutable.ArrayBuffer[InternalRow]()
private var bufferIndex = 0
private var currentOffset = partition.offset
private val fetchSize = 200 // Читаем по 200 строк за раз
private var exhausted = false
// Переиспользуем один объект InternalRow (mutable!)
// Это КРИТИЧЕСКИ важно для производительности -- не создавайте new InternalRow на каждый get()
private val reusableRow = new GenericInternalRow(partition.schema.length)
override def next(): Boolean = {
if (bufferIndex < pageBuffer.length) return true
if (exhausted) return false
// Нужно подгрузить следующую страницу
fetchNextPage()
bufferIndex < pageBuffer.length
}
private def fetchNextPage(): Unit = {
val remaining = partition.limit - (currentOffset - partition.offset)
if (remaining <= 0) {
exhausted = true
return
}
val toFetch = math.min(fetchSize, remaining.toInt)
val records = client.fetchRecords(
offset = currentOffset,
limit = toFetch,
filters = partition.filters
)
pageBuffer.clear()
bufferIndex = 0
for (record <- records) {
val row = buildInternalRow(record, partition.schema)
pageBuffer += row
}
currentOffset += records.length
if (records.length < toFetch) {
exhausted = true // API вернул меньше, чем запрашивали -- конец данных
}
}
private def buildInternalRow(record: Map[String, Any], schema: StructType): InternalRow = {
val values = schema.fields.map { field =>
record.get(field.name) match {
case None | Some(null) => null
case Some(v) => field.dataType match {
case LongType => v.asInstanceOf[Number].longValue()
case DoubleType => v.asInstanceOf[Number].doubleValue()
case StringType => UTF8String.fromString(v.toString)
// Timestamp в Spark хранится как микросекунды от эпохи
case TimestampType =>
DateTimeUtils.millisToMicros(v.asInstanceOf[Long])
case _ => v
}
}
}
new GenericInternalRow(values.asInstanceOf[Array[Any]])
}
override def get(): InternalRow = pageBuffer(bufferIndex - 1 + {bufferIndex += 1; 0})
// Упрощённо: в реальном коде bufferIndex++ должен быть в next()
override def close(): Unit = {
try {
client.close()
} catch {
case e: IOException =>
// Логируем, но не бросаем -- close() не должен прерывать cleanup
logWarning(s"Failed to close HTTP client: ${e.getMessage}")
}
}
}
Write Path: от WriteBuilder до DataWriter
Read path мы реализовали. Теперь write path — запись данных через API.
package com.example.connector
import org.apache.spark.sql.connector.write._
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
class HttpApiWriteBuilder(
options: CaseInsensitiveStringMap,
schema: StructType
) extends WriteBuilder {
override def buildForBatch(): BatchWrite =
new HttpApiBatchWrite(options, schema)
}
class HttpApiBatchWrite(
options: CaseInsensitiveStringMap,
schema: StructType
) extends BatchWrite {
override def createBatchWriterFactory(
physicalWriteInfo: PhysicalWriteInfo
): DataWriterFactory =
new HttpApiDataWriterFactory(
baseUrl = options.get("baseUrl"),
apiKey = options.get("apiKey"),
schema = schema,
batchSize = options.getOrDefault("writeBatchSize", "100").toInt
)
// Вызывается после того как все DataWriter'ы завершили работу
override def commit(messages: Array[WriterCommitMessage]): Unit = {
val totalWritten = messages.collect {
case msg: HttpApiCommitMessage => msg.recordsWritten
}.sum
logInfo(s"Committed $totalWritten records to HTTP API")
// Можно здесь сделать финальный API-вызов для "завершения" транзакции
}
// Вызывается при ошибке -- нужно откатить записанные данные
override def abort(messages: Array[WriterCommitMessage]): Unit = {
val successIds = messages.collect {
case msg: HttpApiCommitMessage => msg.batchIds
}.flatten
// Пытаемся удалить уже записанные батчи
if (successIds.nonEmpty) {
logWarning(s"Aborting write, rolling back batch IDs: ${successIds.mkString(",")}")
// HttpApiClient.deleteBatches(options.get("baseUrl"), successIds)
}
}
}
// Сериализуемое сообщение от executor к driver
case class HttpApiCommitMessage(
partitionId: Int,
recordsWritten: Long,
batchIds: Array[String]
) extends WriterCommitMessage
class HttpApiDataWriterFactory(
baseUrl: String,
apiKey: String,
schema: StructType,
batchSize: Int
) extends DataWriterFactory with Serializable {
override def createWriter(
partitionId: Int,
taskId: Long
): DataWriter[InternalRow] =
new HttpApiDataWriter(baseUrl, apiKey, schema, batchSize, partitionId)
}
class HttpApiDataWriter(
baseUrl: String,
apiKey: String,
schema: StructType,
batchSize: Int,
partitionId: Int
) extends DataWriter[InternalRow] {
private lazy val client = new HttpApiClient(baseUrl, apiKey)
private val buffer = mutable.ArrayBuffer[Map[String, Any]]()
private val committedBatchIds = mutable.ArrayBuffer[String]()
private var totalWritten = 0L
override def write(record: InternalRow): Unit = {
buffer += internalRowToMap(record, schema)
if (buffer.length >= batchSize) {
flushBuffer()
}
}
private def flushBuffer(): Unit = {
if (buffer.isEmpty) return
val batchId = client.postRecords(buffer.toList)
committedBatchIds += batchId
totalWritten += buffer.length
buffer.clear()
}
override def commit(): WriterCommitMessage = {
flushBuffer() // Записываем оставшиеся данные
client.close()
HttpApiCommitMessage(partitionId, totalWritten, committedBatchIds.toArray)
}
override def abort(): Unit = {
buffer.clear()
// При aborte не делаем flush -- отбрасываем буферизованные данные
// Уже записанные batch ID передаём driver'у через специальный механизм
// (в DSv2 abort не возвращает WriterCommitMessage, поэтому используем логирование)
client.close()
}
override def close(): Unit = {
// close() вызывается в finally -- не бросаем исключения
try { client.close() } catch { case _: Exception => }
}
}
Pushdown в деталях: что и как работает
Рассмотрим жизненный цикл pushdown на конкретном запросе:
# Python-код пользователя
df = spark.read.format("http-api") \
.option("baseUrl", "https://api.example.com/data") \
.option("numPartitions", "8") \
.option("apiKey", "secret") \
.load()
result = df.filter(
(df.category == "electronics") &
(df.amount > 50.0) &
(df.name.contains("Pro")) # LIKE -- не поддерживается нашим API
).select("id", "name", "amount")
result.explain("formatted")
Вот что происходит:
Шаг 1: Analyzer создаёт узел DataSourceV2Relation для нашего коннектора.
Шаг 2: Правило V2ScanRelationPushDown находит DataSourceV2Relation и создаёт экземпляр HttpApiScanBuilder. Затем передаёт ему фильтры:
pushFilters([
EqualTo(category, "electronics"), // -> можно pushdown (indexed)
GreaterThan(amount, 50.0), // -> можно pushdown (indexed)
Contains(name, "Pro") // -> нельзя (не indexed, сложный фильтр)
])
pushFilters() вернёт [Contains(name, "Pro")] — Spark применит его после чтения. Два других фильтра перейдут в pushedFiltersBuffer.
Шаг 3: Column pruning — pruneColumns(StructType(id, name, amount)).
Шаг 4: build() создаёт HttpApiScan с readSchema = {id, name, amount} и pushedFilters = [EqualTo(category,"electronics"), GreaterThan(amount, 50.0)].
Шаг 5: Физическое планирование — DataSourceV2Strategy превращает всё это в BatchScanExec с вышестоящим Filter(Contains(name, "Pro")).
== Physical Plan ==
*(1) Filter Contains(name#1, Pro)
+- *(1) BatchScan[id#0, name#1, amount#2]
class: com.example.connector.HttpApiTable
filters: [Contains(name#1, Pro)]
pushed filters: [EqualTo(category, electronics), GreaterThan(amount, 50.0)]
ReadSchema: struct<id:bigint,name:string,amount:double>
На HTTP-уровне каждый запрос от HttpApiInputPartition будет выглядеть как:
GET /data?category=electronics&amount_gt=50.0&offset=0&limit=1000
Фильтр по name применяется Spark-ом поверх результата.
Полный граф pushdown в Spark 4.0
Aggregate pushdown: самый эффективный pushdown
SupportsPushDownAggregates — мощнейший инструмент. Если коннектор умеет агрегировать на стороне сервера, запрос SELECT category, COUNT(*), SUM(amount) FROM t GROUP BY category превращается в один API-вызов, возвращающий несколько строк вместо миллионов:
class HttpApiScanBuilder(...)
extends ScanBuilder
with SupportsPushDownAggregates {
private var aggregation: Option[Aggregation] = None
private var groupBySchema: StructType = _
override def pushAggregation(agg: Aggregation): Boolean = {
// Проверяем, умеет ли наш API делать агрегацию
val supportedAggFunctions = agg.aggregateExpressions().forall {
case _: Count => true
case _: Sum => true
case _: Max => true
case _: Min => true
case _ => false // AVG, STDDEV и т.д. не поддерживаем
}
val onlySupportedGroupBy = agg.groupByExpressions().forall {
case NamedReference(parts) => isSupportedColumn(parts.head)
case _ => false
}
if (supportedAggFunctions && onlySupportedGroupBy) {
aggregation = Some(agg)
// Меняем readSchema: теперь возвращаем агрегированные колонки
// порядок: сначала groupBy-колонки, потом аgg-колонки
groupBySchema = buildAggregatedSchema(agg)
true // "да, я применил агрегацию"
} else {
false // "нет, Spark должен агрегировать сам"
}
}
override def build(): Scan = {
val finalSchema = if (aggregation.isDefined) groupBySchema else schema
new HttpApiScan(finalSchema, options, pushedFiltersBuffer.toArray, aggregation)
}
}
При успешном aggregate pushdown Spark UI покажет:
BatchScan[category#3, count(1)#4L, sum(amount#2)#5]
pushed aggregation: Aggregation(groupBy=[category], agg=[count(*), sum(amount)])
Тестирование коннектора
Правильное тестирование DSv2-коннектора требует Unit-тестов для каждого компонента:
class HttpApiScanBuilderSpec extends AnyFlatSpec with Matchers {
"pushFilters" should "push supported filters and return unsupported" in {
val builder = new HttpApiScanBuilder(HttpApiDataSource.SCHEMA, emptyOptions)
val filters = Array[Filter](
EqualTo("category", "electronics"), // поддерживается
Contains("name", "Pro"), // не поддерживается
GreaterThan("amount", 50.0) // поддерживается
)
val postScanFilters = builder.pushFilters(filters)
// Только Contains должен вернуться как post-scan
postScanFilters should have length 1
postScanFilters.head shouldBe a [Contains]
// Два фильтра должны быть pushed
builder.pushedFilters() should have length 2
}
"pruneColumns" should "reduce readSchema" in {
val builder = new HttpApiScanBuilder(HttpApiDataSource.SCHEMA, emptyOptions)
val requiredSchema = StructType(Seq(
StructField("id", LongType),
StructField("amount", DoubleType)
))
builder.pruneColumns(requiredSchema)
val scan = builder.build().asInstanceOf[HttpApiScan]
scan.readSchema.fieldNames should contain only ("id", "amount")
}
}
Попробуй сам
Возьмём доступный в Spark коннектор — JDBC — и исследуем его pushdown через explain:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, sum as spark_sum, count
spark = SparkSession.builder \
.appName("connector-pushdown-demo") \
.config("spark.jars", "/path/to/postgresql-driver.jar") \
.getOrCreate()
# Загружаем через JDBC -- это DSv2-совместимый коннектор
df = spark.read \
.format("jdbc") \
.option("url", "jdbc:postgresql://localhost/mydb") \
.option("dbtable", "orders") \
.option("user", "postgres") \
.option("password", "secret") \
.option("numPartitionColumn", "id") \
.option("numPartitions", "4") \
.option("lowerBound", "1") \
.option("upperBound", "100000") \
.load()
# Запрос с filter + column pruning + aggregate
result = df.filter(col("amount") > 100) \
.select("category", "amount") \
.groupBy("category") \
.agg(spark_sum("amount").alias("total"))
result.explain("formatted")
# Ожидаемый вывод:
# BatchScan[category#0, amount#1]
# pushed filters: [GreaterThan(amount, 100.0)]
# pushed aggregate: SUM(amount) GROUP BY category
#
# Если JDBC-коннектор поддерживает aggregate pushdown, PostgreSQL выполнит
# GROUP BY + SUM на своей стороне -- Spark получит только результат агрегации!
# Посмотрим реальные SQL-запросы к PostgreSQL через логирование
spark.sparkContext.setLogLevel("DEBUG")
# В логах ищем строки с "SELECT ... FROM orders WHERE amount > 100 ..."
Для тестирования custom коннектора без реального сервера создайте Mock:
# Тест нашего коннектора через SparkSession
from pyspark.sql import SparkSession
spark = SparkSession.builder \
.appName("connector-test") \
.master("local[2]") \
.getOrCreate()
# Регистрируем коннектор (если jar в classpath)
df = spark.read.format("com.example.connector.HttpApiDataSource") \
.option("baseUrl", "http://mock-server:8080/api") \
.option("numPartitions", "2") \
.option("apiKey", "test-key") \
.load()
# Проверяем что pushdown работает
filtered = df.filter("category = 'electronics' AND amount > 50")
plan = filtered._jdf.queryExecution().executedPlan().toString()
assert "PushedFilters: [EqualTo(category, electronics), GreaterThan(amount, 50.0)]" in plan
print("Pushdown работает корректно")
Для локального тестирования write path используйте df.write.format("your-format").mode("append").save() и проверяйте commit(messages) — именно там видны все WriterCommitMessage от всех partition writers. Это хорошее место для проверки идемпотентности: запустите write дважды и убедитесь что результат корректен.