Unit-тестирование PySpark-приложений
Зачем тестировать Spark-код?
Data pipeline без тестов — это бомба замедленного действия. Ошибка в трансформации может:
- Испортить данные тихо — агрегаты считаются неправильно, но pipeline не падает
- Обнаружиться через дни — когда бизнес-отчёт покажет аномальные цифры
- Быть дорогой — перезапуск pipeline на терабайтах данных занимает часы и стоит денег
Unit-тесты ловят ошибки до production. Каждая трансформация тестируется изолированно, на малых данных, за секунды.
Настройка pytest для PySpark
conftest.py с SparkSession fixture
Ключевой паттерн — conftest.py с session-scoped SparkSession. Один SparkSession на весь тестовый прогон (JVM запускается один раз):
# tests/conftest.py
import pytest
from pyspark.sql import SparkSession
@pytest.fixture(scope="session")
def spark():
"""SparkSession для тестов -- один экземпляр на весь прогон."""
session = SparkSession.builder \
.master("local[2]") \
.appName("unit-tests") \
.config("spark.sql.shuffle.partitions", "2") \
.config("spark.default.parallelism", "2") \
.config("spark.ui.enabled", "false") \
.config("spark.driver.bindAddress", "127.0.0.1") \
.getOrCreate()
yield session
session.stop()
Важные настройки для тестов:
| Конфигурация | Значение | Зачем |
|---|---|---|
master("local[2]") | 2 потока | Минимум для параллелизма; local[1] скрывает баги race condition |
shuffle.partitions | 2 | По умолчанию 200 — слишком много для тестовых данных |
spark.ui.enabled | false | Отключает Web UI (ускоряет запуск) |
Структура тестового проекта
my_spark_project/
├── src/
│ └── transformations/
│ ├── __init__.py
│ ├── clean.py # чистые функции-трансформации
│ └── aggregate.py
├── tests/
│ ├── conftest.py # SparkSession fixture
│ ├── test_clean.py
│ └── test_aggregate.py
└── pyproject.toml
Паттерн: трансформации как чистые функции
Главный принцип тестируемого Spark-кода — трансформации как чистые функции. Функция принимает DataFrame и возвращает DataFrame:
# src/transformations/clean.py
from pyspark.sql import DataFrame
from pyspark.sql.functions import col, trim, lower, when
def clean_emails(df: DataFrame) -> DataFrame:
"""Нормализует email: trim + lowercase, невалидные -> NULL."""
return df.withColumn(
"email",
when(
col("email").rlike(r"^[\w.+-]+@[\w-]+\.[\w.]+$"),
lower(trim(col("email")))
)
)
def filter_active_users(df: DataFrame, min_orders: int = 1) -> DataFrame:
"""Оставляет пользователей с >= min_orders заказами."""
return df.filter(col("order_count") >= min_orders)
Такие функции легко тестировать — создаём маленький DataFrame на входе, проверяем результат:
# tests/test_clean.py
from src.transformations.clean import clean_emails, filter_active_users
def test_clean_emails_normalizes_case(spark):
data = [("1", " [email protected] "), ("2", "[email protected]")]
df = spark.createDataFrame(data, ["id", "email"])
result = clean_emails(df)
emails = [row.email for row in result.collect()]
assert emails == ["[email protected]", "[email protected]"]
def test_clean_emails_nullifies_invalid(spark):
data = [("1", "not-an-email"), ("2", "[email protected]")]
df = spark.createDataFrame(data, ["id", "email"])
result = clean_emails(df)
emails = [row.email for row in result.collect()]
assert emails[0] is None
assert emails[1] == "[email protected]"
def test_filter_active_users(spark):
data = [("alice", 5), ("bob", 0), ("carol", 1)]
df = spark.createDataFrame(data, ["name", "order_count"])
result = filter_active_users(df, min_orders=1)
assert result.count() == 2
names = {row.name for row in result.collect()}
assert names == {"alice", "carol"}
DataFrame assertions
assertDataFrameEqual (PySpark 3.5+)
Spark3.5Начиная с PySpark 3.5 доступна встроенная функция assertDataFrameEqual:
from pyspark.testing.utils import assertDataFrameEqual
def test_aggregation_with_assert_df_equal(spark):
input_data = [("Moscow", 100), ("Moscow", 200), ("SPb", 150)]
df = spark.createDataFrame(input_data, ["city", "amount"])
result = df.groupBy("city").sum("amount")
expected = spark.createDataFrame(
[("Moscow", 300), ("SPb", 150)],
["city", "sum(amount)"]
)
# Порядок строк не важен -- сравнивает как множества
assertDataFrameEqual(result, expected)
assertDataFrameEqual сравнивает:
- Схему (имена и типы колонок)
- Данные (значения строк, без учёта порядка)
- NULL значения корректно
Custom assertion helpers (до PySpark 3.5)
Для более ранних версий — вспомогательные функции:
def assert_schema_equal(df, expected_columns: dict):
"""Проверяет имена и типы колонок."""
actual = {f.name: str(f.dataType) for f in df.schema.fields}
assert actual == expected_columns
def assert_row_count(df, expected: int):
"""Проверяет количество строк."""
actual = df.count()
assert actual == expected, f"Expected {expected} rows, got {actual}"
def assert_no_nulls(df, columns: list):
"""Проверяет отсутствие NULL в указанных колонках."""
for col_name in columns:
null_count = df.filter(df[col_name].isNull()).count()
assert null_count == 0, f"Column '{col_name}' has {null_count} NULLs"
Mock внешних источников
В unit-тестах мы не читаем из S3/HDFS/Kafka. Вместо этого:
Паттерн 1: локальные fixture-файлы
# tests/fixtures/users_sample.csv -- маленький CSV для тестов
def test_pipeline_with_csv_fixture(spark, tmp_path):
# Создаём fixture данные
data = [("1", "alice", "[email protected]"), ("2", "bob", "[email protected]")]
fixture_df = spark.createDataFrame(data, ["id", "name", "email"])
fixture_path = str(tmp_path / "users.parquet")
fixture_df.write.parquet(fixture_path)
# Тестируемый код читает из fixture вместо S3
df = spark.read.parquet(fixture_path)
result = clean_emails(df)
assert result.count() == 2
Паттерн 2: параметризованная функция-reader
# src/pipeline.py
def run_pipeline(spark, read_fn):
"""Pipeline принимает функцию чтения -- легко подменить в тестах."""
raw = read_fn(spark)
cleaned = clean_emails(raw)
return cleaned
# Production:
# run_pipeline(spark, lambda s: s.read.parquet("s3a://bucket/users"))
# Test:
# run_pipeline(spark, lambda s: s.createDataFrame(test_data, schema))
Паттерн 3: pytest monkeypatch
def test_pipeline_with_monkeypatch(spark, monkeypatch):
test_data = [("1", "[email protected]")]
test_df = spark.createDataFrame(test_data, ["id", "email"])
# Подменяем функцию чтения
monkeypatch.setattr(
"src.pipeline.read_users",
lambda s: test_df
)
from src.pipeline import process_users
result = process_users(spark)
assert result.count() == 1
Запуск тестов
# Все тесты
pytest tests/ -v
# Только unit-тесты (без интеграционных)
pytest tests/ -v -m "not integration"
# С отчётом покрытия
pytest tests/ --cov=src --cov-report=html
Совет: добавьте conftest.py в корень tests/ — pytest автоматически найдёт fixture spark во всех вложенных тестовых файлах. Не нужно импортировать fixture явно.
Что дальше?
В следующем уроке разберём интеграционное тестирование — как тестировать Spark-код с реальными внешними системами через Testcontainers и pytest markers.