transformWithState: новый stateful API
mapGroupsWithState и flatMapGroupsWithState появились в Spark 2.2 и долгое время были единственным способом реализовать arbitrary stateful logic в Structured Streaming. К Spark 4.0 накопился список из семи фундаментальных ограничений этих операторов. transformWithState — это полная переработка arbitrary state API, доступная в Spark 4.0 как стабильный API. В этом уроке разберём мотивацию, интерфейс, исполнение через TransformWithStateExec, таймеры, TTL и ключевые отличия от старого API.
Ограничения старого API
Перед тем как смотреть на новый API, нужно понять, что конкретно не работало в старом.
Ограничение 1: одна плоская структура state. mapGroupsWithState принимает единственный тип состояния S. Если нужно несколько независимых state-переменных (например, lastEventTime: Long и windowCounts: Map[String, Int] и activeSessionStart: Option[Long]), всё приходилось упаковывать в один case class. Это создаёт проблемы с эволюцией схемы — добавить поле без потери state нельзя.
Ограничение 2: один таймер на ключ. Оператор поддерживал только one expiry timestamp per key. Нельзя одновременно поставить «напомнить через 5 минут» и «очистить через 24 часа».
Ограничение 3: нет side outputs. Невозможно маршрутизировать ошибочные записи в отдельный поток (dead letter queue) и нормальные — в основной вывод.
Ограничение 4: нет chaining после оператора. После flatMapGroupsWithState нельзя ставить другой stateful оператор — это ограничение физического планировщика.
Ограничение 5: нет schema evolution. Изменение типа state требует полного перезапуска с чистым checkpoint.
Ограничение 6: смешанная логика. Обработка «живых» данных и обработка expired state (пустой iterator) смешаны в одном методе — код становится сложным и error-prone.
Ограничение 7: нет инициализации из внешнего источника. Нельзя предзаполнить state из существующего dataset’а при старте.
StatefulProcessor: архитектура нового API
В transformWithState пользователь реализует StatefulProcessor[K, I, O]:
// Параметры: K = тип ключа, I = тип входной строки, O = тип выходной строки
abstract class StatefulProcessor[K, I, O] extends Serializable {
// Доступ к handle — нельзя сохранять как field вне init/handle методов
protected def getHandle: StatefulProcessorHandle
// Инициализация: здесь объявляются state-переменные
def init(outputMode: OutputMode, timeMode: TimeMode): Unit
// Обработка входных строк для ключа
def handleInputRows(
key: K,
rows: Iterator[I],
timerValues: TimerValues
): Iterator[O]
// Обработка истёкших таймеров
def handleExpiredTimer(
key: K,
timerValues: TimerValues,
expiredTimerInfo: ExpiredTimerInfo
): Iterator[O] = Iterator.empty // по умолчанию — ничего не делать
// Очистка ресурсов при остановке запроса
def close(): Unit = {}
}
Опциональный расширенный интерфейс StatefulProcessorWithInitialState[K, I, O, S] добавляет метод handleInitialState(key: K, initialState: S, timerValues: TimerValues) для предзаполнения state из batch dataset’а.
StatefulProcessorHandle: регистрация state-переменных
StatefulProcessorHandle — это контекст, через который StatefulProcessor получает доступ к state и таймерам. Он инжектируется системой и живёт в init():
trait StatefulProcessorHandle {
// Создание state-переменных
def getValueState[T](stateName: String, encoder: Encoder[T]): ValueState[T]
def getValueState[T](stateName: String, encoder: Encoder[T], ttlConfig: TTLConfig): ValueState[T]
def getListState[T](stateName: String, encoder: Encoder[T]): ListState[T]
def getListState[T](stateName: String, encoder: Encoder[T], ttlConfig: TTLConfig): ListState[T]
def getMapState[K, V](stateName: String, userKeyEncoder: Encoder[K], valueEncoder: Encoder[V]): MapState[K, V]
def getMapState[K, V](..., ttlConfig: TTLConfig): MapState[K, V]
// Работа с таймерами
def registerTimer(expiryTimestampMs: Long): Unit
def deleteTimer(expiryTimestampMs: Long): Unit
def listTimers(): Iterator[Long]
// Удаление state-переменной (для schema evolution)
def deleteIfExists(stateName: String): Unit
}
ValueState, ListState, MapState
ValueState[T] — одно значение на ключ:
trait ValueState[T] {
def exists(): Boolean // есть ли значение для текущего ключа?
def get(): T // прочитать значение
def update(value: T): Unit // записать значение
def clear(): Unit // удалить
}
ListState[T] — список значений на ключ, оптимизирован для append:
trait ListState[T] {
def exists(): Boolean
def get(): Iterator[T]
def put(newState: java.util.List[T]): Unit // заменить весь список
def appendValue(newState: T): Unit // добавить один элемент
def appendList(newState: java.util.List[T]): Unit // добавить несколько
def clear(): Unit
}
MapState[K, V] — вложенная map на ключ, оптимизирован для point lookups:
trait MapState[K, V] {
def exists(): Boolean
def getValue(key: K): V
def containsKey(key: K): Boolean
def updateValue(key: K, value: V): Unit
def iterator(): Iterator[java.util.Map.Entry[K, V]]
def keys(): Iterator[K]
def values(): Iterator[V]
def removeKey(key: K): Unit
def clear(): Unit
}
Три типа state-переменных с разной семантикой доступа. Каждая переменная — отдельный column family в RocksDB.
Пример: session detection с несколькими state-переменными
import org.apache.spark.sql.streaming.{StatefulProcessor, StatefulProcessorHandle,
OutputMode, TimeMode, TimerValues}
case class Event(userId: String, action: String, timestamp: Long)
case class SessionOutput(userId: String, sessionStart: Long, sessionEnd: Long, eventCount: Int)
class SessionDetector extends StatefulProcessor[String, Event, SessionOutput] {
@transient private var sessionStart: ValueState[Long] = _
@transient private var eventCount: ValueState[Int] = _
@transient private var lastEventTime: ValueState[Long] = _
val SESSION_TIMEOUT_MS = 30 * 60 * 1000L // 30 минут
override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
// Регистрируем три независимых state-переменных
sessionStart = getHandle.getValueState[Long]("sessionStart", Encoders.scalaLong)
eventCount = getHandle.getValueState[Int]("eventCount", Encoders.scalaInt)
lastEventTime = getHandle.getValueState[Long]("lastEventTime", Encoders.scalaLong)
}
override def handleInputRows(
key: String,
rows: Iterator[Event],
timerValues: TimerValues
): Iterator[SessionOutput] = {
val events = rows.toList
val latestTs = events.map(_.timestamp).max
if (!sessionStart.exists()) {
// Новая сессия
sessionStart.update(events.map(_.timestamp).min)
eventCount.update(events.size)
} else {
// Продолжение сессии
eventCount.update(eventCount.get() + events.size)
}
lastEventTime.update(latestTs)
// Отменяем старый таймер (если был) и ставим новый
getHandle.listTimers().foreach(t => getHandle.deleteTimer(t))
getHandle.registerTimer(latestTs + SESSION_TIMEOUT_MS)
Iterator.empty // результат испускается при истечении таймера
}
override def handleExpiredTimer(
key: String,
timerValues: TimerValues,
expiredTimerInfo: ExpiredTimerInfo
): Iterator[SessionOutput] = {
// Сессия завершилась — испускаем результат и очищаем state
val output = SessionOutput(
userId = key,
sessionStart = sessionStart.get(),
sessionEnd = lastEventTime.get(),
eventCount = eventCount.get()
)
sessionStart.clear()
eventCount.clear()
lastEventTime.clear()
Iterator(output)
}
}
// Использование
val sessionResults = events
.groupBy("userId")
.transformWithState(
new SessionDetector(),
OutputMode.Append(),
TimeMode.ProcessingTime()
)
Обратите внимание: три state-переменных объявлены в init(), а логика обработки и логика завершения — в разных методах. Это радикально чище, чем старый API.
Таймеры: Processing Time vs Event Time
transformWithState поддерживает два режима времени, задаваемых через TimeMode:
TimeMode.ProcessingTime() — таймеры срабатывают по wall-clock времени:
// Зарегистрировать таймер через 5 секунд от текущего processing time
getHandle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + 5000L)
TimeMode.EventTime() — таймеры срабатывают когда watermark продвигается за указанную метку:
// Таймер сработает, когда watermark достигнет eventTime + 1 час
getHandle.registerTimer(eventTime + 3600_000L)
// timerValues.getCurrentEventTimeInMs() — текущий watermark
TimeMode.None() — таймеры не поддерживаются, только pure state-трансформации.
Важное ограничение Processing Time таймеров: они не строго гарантированы на точное время. Если батч в ProcessingTime-режиме стартует каждые 10 секунд, таймер «через 5 секунд» сработает на следующем батче, а не ровно через 5 секунд.
// Python API (transformWithStateInPandas)
class DowntimeDetector(StatefulProcessor):
def init(self, handle):
state_schema = StructType([StructField("last_seen", TimestampType())])
self.last_seen = handle.getValueState("last_seen", state_schema)
def handleInputRows(self, key, rows, timerValues):
latest = max(rows, key=lambda r: r["timestamp"])
self.last_seen.update((latest["timestamp"],))
# Удалить старые таймеры и поставить новый
for t in self.handle.listTimers():
self.handle.deleteTimer(t)
self.handle.registerTimer(
timerValues.getCurrentProcessingTimeInMs() + 5_000)
return iter([])
def handleExpiredTimer(self, key, timerValues, expiredTimerInfo):
last = self.last_seen.get()
downtime_ms = timerValues.getCurrentProcessingTimeInMs() - \
int(last[0].timestamp() * 1000)
yield {"device_id": key[0], "downtime_seconds": downtime_ms // 1000}
TTL: автоматическое удаление просроченного state
TTL (Time-To-Live) — механизм автоматического удаления state-переменных по истечении заданного времени, без ручной логики в handleExpiredTimer:
import org.apache.spark.sql.streaming.TTLConfig
import java.time.Duration
// ValueState с TTL 24 часа
val recentEvents = getHandle.getValueState[Long](
"recentEvents",
Encoders.scalaLong,
TTLConfig(Duration.ofHours(24)) // автоудаление через 24 часа неактивности
)
TTL работает в двух режимах:
- Processing time TTL: state удаляется через N миллисекунд processing time после последнего update
- Event time TTL: state удаляется когда watermark продвигается за
(lastUpdateEventTime + ttl)
TTL реализован через secondary index в StateStore: при каждом update() записывается пара (expiryTime, key). TransformWithStateExec вызывает doTtlCleanup() после каждого батча, который сканирует secondary index и удаляет expired entries.
Используйте TTL вместо таймеров для простых сценариев «удалить state через N времени после последнего обновления». Таймеры мощнее (кастомная логика при срабатывании), но TTL — дешевле (нет overhead на регистрацию/удаление таймеров) и проще в коде.
TransformWithStateExec: физическое исполнение
TransformWithStateExec — физический оператор, реализующий transformWithState. Его ключевые фазы в doExecute():
Per-partition execution:
1. Инициализация StateStore для этой партиции
2. Вызов statefulProcessor.init(outputMode, timeMode)
3. Если есть initialState DataFrame и currentBatchId == 0:
-> группировать initialState по ключу
-> вызвать statefulProcessor.handleInitialState(key, state, timerValues)
4. Для каждой группы входных строк (группировка по ключу):
-> вызвать statefulProcessor.handleInputRows(key, rows, timerValues)
-> собрать Iterator[O] в выходной буфер
5. Для каждого истёкшего таймера (ordered by expiry):
-> вызвать statefulProcessor.handleExpiredTimer(key, timerValues, info)
6. doTtlCleanup(): удалить expired state из secondary index
7. store.commit(): записать delta в checkpoint
Критическое отличие от FlatMapGroupsWithStateExec: TransformWithStateExec обрабатывает истёкшие таймеры в том же батче что и входные данные, в детерминированном порядке (сначала input, потом expired timers). В старом API expired state определялся по пустому iterator — это был хак, а не явный механизм.
Четыре последовательные фазы в одном батче. Таймеры обрабатываются после входных данных.
Batch mode: isStreaming = false
В Spark 4.0 TransformWithStateExec поддерживает batch mode через флаг isStreaming. Это позволяет использовать тот же StatefulProcessor для одноразовой обработки статических данных (например, для тестирования или исторических расчётов):
# Batch mode: обрабатываем статический DataFrame
batch_df = spark.read.parquet("/data/events")
result = batch_df \
.groupBy("userId") \
.transformWithState(
statefulProcessor=MyProcessor(),
outputMode=OutputMode.Append,
timeMode=TimeMode.None
)
result.write.parquet("/output")
В batch mode TransformWithStateExec не создаёт checkpoint и не восстанавливает state — это однократное применение логики. Ценно для тестирования: можно написать StatefulProcessor и отладить его на batch data перед запуском как streaming.
Schema evolution: добавление state-переменных
Одно из ключевых преимуществ transformWithState перед старым API — поддержка schema evolution без полного перезапуска:
// Версия 1: только sessionStart
override def init(...): Unit = {
sessionStart = getHandle.getValueState[Long]("sessionStart", Encoders.scalaLong)
}
// Версия 2: добавляем eventCount (безопасно!)
override def init(...): Unit = {
sessionStart = getHandle.getValueState[Long]("sessionStart", Encoders.scalaLong)
eventCount = getHandle.getValueState[Int]("eventCount", Encoders.scalaInt)
// eventCount.exists() == false для ключей, где он не был записан
// это ок: если eventCount.exists() == false, считаем 0
}
// Версия 3: удаляем obsolete state-переменную
override def init(...): Unit = {
getHandle.deleteIfExists("legacyField") // удаляет из StateStore
sessionStart = getHandle.getValueState[Long]("sessionStart", Encoders.scalaLong)
eventCount = getHandle.getValueState[Int]("eventCount", Encoders.scalaInt)
}
Для evolution внутри state-переменной (изменение схемы типа T) нужно Avro-кодирование:
spark.conf.set("spark.sql.streaming.stateStore.encodingFormat", "avro")
# Поддерживает: add field, remove field, type widening, reorder fields
# Не поддерживает: rename field, type narrowing
Сравнение transformWithState vs mapGroupsWithState
| Аспект | mapGroupsWithState | transformWithState |
|---|---|---|
| Типы state | Один объект S | ValueState, ListState, MapState |
| Таймеры | Один на ключ | Множество на ключ |
| Expired state | Пустой iterator (хак) | Отдельный handleExpiredTimer |
| Schema evolution | Требует restart | deleteIfExists, Avro |
| Chaining | Запрещён | Разрешён (SPARK-47960) |
| Side outputs | Нет | Через multiple outputs (roadmap) |
| Инициализация | Нет | handleInitialState |
| Batch mode | Нет | Да (isStreaming=false) |
| Языки | Scala, Java | Scala, Java, Python (Pandas), R |
Попробуй сам
Реализуйте детектор аномалий с несколькими state-переменными и таймером:
from pyspark.sql import SparkSession
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import StructType, StructField, StringType, LongType, DoubleType
import pandas as pd, time
spark = SparkSession.builder \
.appName("transform-with-state-demo") \
.config("spark.sql.streaming.stateStore.providerClass",
"org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider") \
.getOrCreate()
output_schema = StructType([
StructField("sensor_id", StringType()),
StructField("alert", StringType()),
StructField("avg_value", DoubleType()),
StructField("count", LongType())
])
class AnomalyDetector(StatefulProcessor):
"""Детектирует аномалии: если среднее значение > 2 * baseline за последние N событий."""
WINDOW_SIZE = 10
ALERT_TIMEOUT_MS = 60_000 # 60 секунд без данных -> alert
def init(self, handle: StatefulProcessorHandle):
value_schema = StructType([StructField("v", DoubleType())])
count_schema = StructType([StructField("v", LongType())])
self.handle = handle
self.baseline = handle.getValueState("baseline", value_schema)
self.sum_values = handle.getValueState("sum_values", value_schema)
self.event_count = handle.getValueState("event_count", count_schema)
def handleInputRows(self, key, rows, timerValues):
for pdf in rows:
for _, row in pdf.iterrows():
val = float(row["value"])
if not self.event_count.exists():
self.event_count.update((0,))
self.sum_values.update((0.0,))
self.baseline.update((val,))
cnt = int(self.event_count.get()[0]) + 1
s = float(self.sum_values.get()[0]) + val
self.event_count.update((cnt,))
self.sum_values.update((s,))
avg = s / cnt
# Обновляем таймер "нет данных"
for t in self.handle.listTimers():
self.handle.deleteTimer(t)
self.handle.registerTimer(
timerValues.getCurrentProcessingTimeInMs() + self.ALERT_TIMEOUT_MS)
baseline = float(self.baseline.get()[0])
if cnt > 5 and avg > 2 * baseline:
yield pd.DataFrame({
"sensor_id": [key[0]],
"alert": ["HIGH_VALUE"],
"avg_value": [avg],
"count": [cnt]
})
return
def handleExpiredTimer(self, key, timerValues, expiredTimerInfo):
yield pd.DataFrame({
"sensor_id": [key[0]],
"alert": ["NO_DATA"],
"avg_value": [0.0],
"count": [0]
})
self.baseline.clear()
self.sum_values.clear()
self.event_count.clear()
def close(self):
pass
# Генерируем данные
df = spark.readStream.format("rate").option("rowsPerSecond", 10).load()
from pyspark.sql.functions import col, (col("value") % 5).alias("sensor_id"), \
(col("value").cast("double") * 1.5).alias("value")
sensors = df.select(
(col("value") % 5).cast("string").alias("sensor_id"),
(col("value").cast("double")).alias("value"),
col("timestamp")
)
alerts = sensors \
.groupBy("sensor_id") \
.transformWithStateInPandas(
statefulProcessor=AnomalyDetector(),
outputStructType=output_schema,
outputMode="update",
timeMode="processingTime"
)
query = alerts.writeStream \
.outputMode("update") \
.format("console") \
.trigger(processingTime="5 seconds") \
.start()
time.sleep(30)
query.stop()
Ожидаемый вывод через 15-20 секунд:
+----------+-----------+---------+-----+
|sensor_id |alert |avg_value|count|
+----------+-----------+---------+-----+
|3 |HIGH_VALUE |45.2 |8 |
|1 |HIGH_VALUE |37.8 |9 |
+----------+-----------+---------+-----+