Learning Platform
Глоссарий Troubleshooting
Урок 06.02 · 20 мин
Продвинутый
AccumulatorUDAFUDWFPartitionEvaluatorcreate_udafcreate_udwfAggregateUDFImpl

Агрегатные и оконные UDF

В модуле 04 мы создавали агрегатные функции на Python через класс Accumulator. Rust API предоставляет trait AggregateUDFImpl с полным контролем над жизненным циклом агрегации — инициализация состояния, инкрементальное обновление, слияние партиций и финальное вычисление.

UDAF: пользовательские агрегатные функции

Trait Accumulator

Accumulator — ядро агрегатной функции. Он управляет состоянием между батчами строк:

Accumulator: жизненный цикл
Создание AccumulatorDataFusion создаёт новый экземпляр Accumulator для каждой партиции данных
Partition 1
update_batch(values)Первый батч строк обновляет внутреннее состояние аккумулятора — суммы, счётчики, коллекции
Partition 2
update_batch(values)Каждый следующий батч инкрементально обновляет состояние без пересоздания аккумулятора
Слияние партиций
merge_batch(states)Промежуточные состояния из параллельных партиций объединяются в единый аккумулятор
Финал
evaluate() → ScalarValueФинальное вычисление преобразует внутреннее состояние в одно скалярное значение результата

При параллельном выполнении DataFusion создаёт отдельный Accumulator для каждой партиции. Метод state() сериализует промежуточное состояние, а merge_batch() объединяет состояния из разных партиций.

Полная реализация: средневзвешенное

use datafusion::logical_expr::Accumulator;
use datafusion::arrow::array::{ArrayRef, Float64Array};
use datafusion::common::{Result, ScalarValue};

#[derive(Debug)]
struct WeightedAvgAccumulator {
    sum_product: f64,
    sum_weights: f64,
}

impl WeightedAvgAccumulator {
    fn new() -> Self {
        Self { sum_product: 0.0, sum_weights: 0.0 }
    }
}

impl Accumulator for WeightedAvgAccumulator {
    fn state(&self) -> Result<Vec<ScalarValue>> {
        // Промежуточное состояние — два числа
        Ok(vec![
            ScalarValue::Float64(Some(self.sum_product)),
            ScalarValue::Float64(Some(self.sum_weights)),
        ])
    }

    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        // values[0] = значения, values[1] = веса
        let vals = values[0].as_any().downcast_ref::<Float64Array>().unwrap();
        let weights = values[1].as_any().downcast_ref::<Float64Array>().unwrap();

        for i in 0..vals.len() {
            if !vals.is_null(i) && !weights.is_null(i) {
                let v = vals.value(i);
                let w = weights.value(i);
                self.sum_product += v * w;
                self.sum_weights += w;
            }
        }
        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        // Слияние состояний из параллельных партиций
        let products = states[0].as_any().downcast_ref::<Float64Array>().unwrap();
        let weights = states[1].as_any().downcast_ref::<Float64Array>().unwrap();

        for i in 0..products.len() {
            if !products.is_null(i) {
                self.sum_product += products.value(i);
            }
            if !weights.is_null(i) {
                self.sum_weights += weights.value(i);
            }
        }
        Ok(())
    }

    fn evaluate(&self) -> Result<ScalarValue> {
        if self.sum_weights == 0.0 {
            Ok(ScalarValue::Float64(None))
        } else {
            Ok(ScalarValue::Float64(Some(self.sum_product / self.sum_weights)))
        }
    }

    fn size(&self) -> usize {
        std::mem::size_of::<Self>()
    }
}
NOTE

Метод size() возвращает потребление памяти аккумулятором. DataFusion использует его для контроля memory budget при параллельных агрегациях с большим количеством групп.

Trait AggregateUDFImpl

AggregateUDFImpl описывает метаданные агрегатной функции и создаёт Accumulator:

use datafusion::logical_expr::{AggregateUDFImpl, Signature, Volatility};
use datafusion::arrow::datatypes::{DataType, Field};
use std::any::Any;

#[derive(Debug)]
struct WeightedAvgUdaf {
    signature: Signature,
}

impl WeightedAvgUdaf {
    fn new() -> Self {
        Self {
            signature: Signature::exact(
                vec![DataType::Float64, DataType::Float64],
                Volatility::Immutable,
            ),
        }
    }
}

impl AggregateUDFImpl for WeightedAvgUdaf {
    fn as_any(&self) -> &dyn Any { self }
    fn name(&self) -> &str { "weighted_avg" }
    fn signature(&self) -> &Signature { &self.signature }

    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Float64)
    }

    fn accumulator(&self, _acc_args: &datafusion::logical_expr::AccumulatorArgs)
        -> Result<Box<dyn Accumulator>>
    {
        Ok(Box::new(WeightedAvgAccumulator::new()))
    }

    fn state_fields(&self, _args: datafusion::logical_expr::StateFieldsArgs)
        -> Result<Vec<Field>>
    {
        // Описание полей промежуточного состояния
        Ok(vec![
            Field::new("sum_product", DataType::Float64, true),
            Field::new("sum_weights", DataType::Float64, true),
        ])
    }
}

Регистрация UDAF

use datafusion::logical_expr::AggregateUDF;
use datafusion::prelude::SessionContext;

let udaf = AggregateUDF::new_from_impl(WeightedAvgUdaf::new());
let ctx = SessionContext::new();
ctx.register_udaf(udaf);

// Использование в SQL
let df = ctx.sql(
    "SELECT department, weighted_avg(salary, experience) FROM employees GROUP BY department"
).await?;
df.show().await?;

create_udaf: быстрый способ

Для простых случаев без реализации trait:

use datafusion::logical_expr::create_udaf;

let sum_udaf = create_udaf(
    "custom_sum",
    vec![DataType::Float64],        // Типы аргументов
    Arc::new(DataType::Float64),    // Тип результата
    Volatility::Immutable,
    Arc::new(|_| Ok(Box::new(SumAccumulator::new()))),  // Фабрика аккумулятора
    Arc::new(vec![DataType::Float64]),  // Типы состояния
);
WARNING

Функции create_udaf и create_udwf считаются устаревшими начиная с DataFusion 42.x. Для новых проектов используйте AggregateUDF::new_from_impl(impl AggregateUDFImpl) для агрегатных функций и WindowUDF::new_from_impl(impl WindowUDFImpl) для оконных. Convenience-функции сохранены только для обратной совместимости.

UDWF: пользовательские оконные функции

Оконные функции вычисляют значения в контексте окна строк (OVER clause). Rust API использует trait PartitionEvaluator.

Trait PartitionEvaluator

use datafusion::logical_expr::PartitionEvaluator;
use datafusion::arrow::array::{ArrayRef, Float64Array};
use datafusion::common::Result;
use std::sync::Arc;

#[derive(Debug)]
struct ExponentialMovingAvg {
    alpha: f64,
}

impl ExponentialMovingAvg {
    fn new(alpha: f64) -> Self {
        Self { alpha }
    }
}

impl PartitionEvaluator for ExponentialMovingAvg {
    fn evaluate_all(
        &mut self,
        values: &[ArrayRef],
        num_rows: usize,
    ) -> Result<ArrayRef> {
        let input = values[0].as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();

        let mut ema = Vec::with_capacity(num_rows);
        let mut prev = 0.0;

        for i in 0..num_rows {
            if input.is_null(i) {
                ema.push(None);
            } else {
                let val = input.value(i);
                let current = if i == 0 {
                    val
                } else {
                    self.alpha * val + (1.0 - self.alpha) * prev
                };
                prev = current;
                ema.push(Some(current));
            }
        }

        Ok(Arc::new(Float64Array::from(ema)) as ArrayRef)
    }

    fn uses_window_frame(&self) -> bool {
        false  // EMA обрабатывает всю партицию, не зависит от ROWS BETWEEN
    }

    fn include_rank(&self) -> bool {
        false  // Не нужна информация о ранге
    }
}
TIP

evaluate_all вызывается один раз для всей партиции окна. Если ваша функция зависит от оконного фрейма (ROWS BETWEEN), верните true из uses_window_frame() и реализуйте evaluate вместо evaluate_all.

Trait WindowUDFImpl и регистрация

use datafusion::logical_expr::{WindowUDFImpl, Signature, Volatility, WindowUDF};
use datafusion::arrow::datatypes::{DataType, Field};

#[derive(Debug)]
struct EmaUdwf {
    signature: Signature,
}

impl EmaUdwf {
    fn new() -> Self {
        Self {
            signature: Signature::exact(
                vec![DataType::Float64],
                Volatility::Immutable,
            ),
        }
    }
}

impl WindowUDFImpl for EmaUdwf {
    fn as_any(&self) -> &dyn Any { self }
    fn name(&self) -> &str { "ema" }
    fn signature(&self) -> &Signature { &self.signature }

    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Float64)
    }

    fn partition_evaluator(
        &self,
        _partition_evaluator_args: datafusion::logical_expr::PartitionEvaluatorArgs,
    ) -> Result<Box<dyn PartitionEvaluator>> {
        Ok(Box::new(ExponentialMovingAvg::new(0.3)))
    }

    fn field(&self, field_args: datafusion::logical_expr::WindowUDFFieldArgs)
        -> Result<Field>
    {
        Ok(Field::new(field_args.name(), DataType::Float64, true))
    }
}

// Регистрация
let udwf = WindowUDF::new_from_impl(EmaUdwf::new());
let ctx = SessionContext::new();
ctx.register_udwf(udwf);

let df = ctx.sql(
    "SELECT ts, price, ema(price) OVER (ORDER BY ts) FROM stock_prices"
).await?;

Сравнение: UDF vs UDAF vs UDWF

Типы пользовательских функций
Scalar UDFСкалярная функция: вычисляет одно значение для каждой входной строки — map-операция над массивами
Aggregate UDAFАгрегатная функция: сводит N строк к одному значению через Accumulator с поддержкой параллельного слияния
Window UDWFОконная функция: вычисляет значение для каждой строки в контексте окна через PartitionEvaluator

Полный пример: медианная агрегация

use datafusion::logical_expr::{Accumulator, AggregateUDFImpl, AggregateUDF, Signature, Volatility};
use datafusion::arrow::array::{ArrayRef, Float64Array};
use datafusion::arrow::datatypes::{DataType, Field};
use datafusion::common::{Result, ScalarValue};
use std::any::Any;

#[derive(Debug)]
struct MedianAccumulator {
    values: Vec<f64>,
}

impl Accumulator for MedianAccumulator {
    fn state(&self) -> Result<Vec<ScalarValue>> {
        // Сериализуем все значения как List
        Ok(vec![ScalarValue::List(
            ScalarValue::new_list(
                &self.values.iter().map(|v| ScalarValue::Float64(Some(*v))).collect::<Vec<_>>(),
                &DataType::Float64,
            )
        )])
    }

    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let array = values[0].as_any().downcast_ref::<Float64Array>().unwrap();
        for i in 0..array.len() {
            if !array.is_null(i) {
                self.values.push(array.value(i));
            }
        }
        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        // Разворачиваем List-состояния из других партиций
        let list_array = &states[0];
        // Упрощённо: итерируем по вложенным значениям
        let values = list_array.as_any().downcast_ref::<Float64Array>();
        if let Some(arr) = values {
            for i in 0..arr.len() {
                if !arr.is_null(i) {
                    self.values.push(arr.value(i));
                }
            }
        }
        Ok(())
    }

    fn evaluate(&self) -> Result<ScalarValue> {
        if self.values.is_empty() {
            return Ok(ScalarValue::Float64(None));
        }
        let mut sorted = self.values.clone();
        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
        let mid = sorted.len() / 2;
        let median = if sorted.len() % 2 == 0 {
            (sorted[mid - 1] + sorted[mid]) / 2.0
        } else {
            sorted[mid]
        };
        Ok(ScalarValue::Float64(Some(median)))
    }

    fn size(&self) -> usize {
        std::mem::size_of::<Self>() + self.values.capacity() * std::mem::size_of::<f64>()
    }
}

#[derive(Debug)]
struct MedianUdaf {
    signature: Signature,
}

impl AggregateUDFImpl for MedianUdaf {
    fn as_any(&self) -> &dyn Any { self }
    fn name(&self) -> &str { "median" }
    fn signature(&self) -> &Signature { &self.signature }
    fn return_type(&self, _: &[DataType]) -> Result<DataType> { Ok(DataType::Float64) }
    fn accumulator(&self, _: &datafusion::logical_expr::AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        Ok(Box::new(MedianAccumulator { values: vec![] }))
    }
    fn state_fields(&self, _: datafusion::logical_expr::StateFieldsArgs) -> Result<Vec<Field>> {
        Ok(vec![Field::new("values", DataType::List(Arc::new(Field::new("item", DataType::Float64, true))), true)])
    }
}
WARNING

Медианная агрегация хранит все значения в памяти. Для больших датасетов это может вызвать OOM. В production используйте approximate median (t-digest) или ограничивайте размер через size().

Итоги

  • Accumulator trait управляет жизненным циклом агрегации: stateupdate_batchmerge_batchevaluate
  • AggregateUDFImpl + Accumulator — полный контроль; create_udaf — быстрый способ
  • state_fields() описывает промежуточное состояние для сериализации между партициями
  • PartitionEvaluator для UDWF обрабатывает строки в контексте оконного фрейма
  • WindowUDFImpl + PartitionEvaluator — оконные функции с evaluate_all или evaluate
  • size() в Accumulator — контроль памяти при параллельных агрегациях с большим числом групп

Проверьте понимание

Результат: 0 из 0
Концептуальный
Вопрос 1 из 5. В каком порядке DataFusion вызывает методы Accumulator при параллельной агрегации по нескольким партициям?

Закончили урок?

Отметьте его как пройденный, чтобы отслеживать свой прогресс

Войдите чтобы оценить урок

Прогресс модуля
0 из 8