Агрегатные и оконные UDF
В модуле 04 мы создавали агрегатные функции на Python через класс Accumulator. Rust API предоставляет trait AggregateUDFImpl с полным контролем над жизненным циклом агрегации — инициализация состояния, инкрементальное обновление, слияние партиций и финальное вычисление.
UDAF: пользовательские агрегатные функции
Trait Accumulator
Accumulator — ядро агрегатной функции. Он управляет состоянием между батчами строк:
При параллельном выполнении DataFusion создаёт отдельный Accumulator для каждой партиции. Метод state() сериализует промежуточное состояние, а merge_batch() объединяет состояния из разных партиций.
Полная реализация: средневзвешенное
use datafusion::logical_expr::Accumulator;
use datafusion::arrow::array::{ArrayRef, Float64Array};
use datafusion::common::{Result, ScalarValue};
#[derive(Debug)]
struct WeightedAvgAccumulator {
sum_product: f64,
sum_weights: f64,
}
impl WeightedAvgAccumulator {
fn new() -> Self {
Self { sum_product: 0.0, sum_weights: 0.0 }
}
}
impl Accumulator for WeightedAvgAccumulator {
fn state(&self) -> Result<Vec<ScalarValue>> {
// Промежуточное состояние — два числа
Ok(vec![
ScalarValue::Float64(Some(self.sum_product)),
ScalarValue::Float64(Some(self.sum_weights)),
])
}
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
// values[0] = значения, values[1] = веса
let vals = values[0].as_any().downcast_ref::<Float64Array>().unwrap();
let weights = values[1].as_any().downcast_ref::<Float64Array>().unwrap();
for i in 0..vals.len() {
if !vals.is_null(i) && !weights.is_null(i) {
let v = vals.value(i);
let w = weights.value(i);
self.sum_product += v * w;
self.sum_weights += w;
}
}
Ok(())
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
// Слияние состояний из параллельных партиций
let products = states[0].as_any().downcast_ref::<Float64Array>().unwrap();
let weights = states[1].as_any().downcast_ref::<Float64Array>().unwrap();
for i in 0..products.len() {
if !products.is_null(i) {
self.sum_product += products.value(i);
}
if !weights.is_null(i) {
self.sum_weights += weights.value(i);
}
}
Ok(())
}
fn evaluate(&self) -> Result<ScalarValue> {
if self.sum_weights == 0.0 {
Ok(ScalarValue::Float64(None))
} else {
Ok(ScalarValue::Float64(Some(self.sum_product / self.sum_weights)))
}
}
fn size(&self) -> usize {
std::mem::size_of::<Self>()
}
}
Метод size() возвращает потребление памяти аккумулятором. DataFusion использует его для контроля memory budget при параллельных агрегациях с большим количеством групп.
Trait AggregateUDFImpl
AggregateUDFImpl описывает метаданные агрегатной функции и создаёт Accumulator:
use datafusion::logical_expr::{AggregateUDFImpl, Signature, Volatility};
use datafusion::arrow::datatypes::{DataType, Field};
use std::any::Any;
#[derive(Debug)]
struct WeightedAvgUdaf {
signature: Signature,
}
impl WeightedAvgUdaf {
fn new() -> Self {
Self {
signature: Signature::exact(
vec![DataType::Float64, DataType::Float64],
Volatility::Immutable,
),
}
}
}
impl AggregateUDFImpl for WeightedAvgUdaf {
fn as_any(&self) -> &dyn Any { self }
fn name(&self) -> &str { "weighted_avg" }
fn signature(&self) -> &Signature { &self.signature }
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Float64)
}
fn accumulator(&self, _acc_args: &datafusion::logical_expr::AccumulatorArgs)
-> Result<Box<dyn Accumulator>>
{
Ok(Box::new(WeightedAvgAccumulator::new()))
}
fn state_fields(&self, _args: datafusion::logical_expr::StateFieldsArgs)
-> Result<Vec<Field>>
{
// Описание полей промежуточного состояния
Ok(vec![
Field::new("sum_product", DataType::Float64, true),
Field::new("sum_weights", DataType::Float64, true),
])
}
}
Регистрация UDAF
use datafusion::logical_expr::AggregateUDF;
use datafusion::prelude::SessionContext;
let udaf = AggregateUDF::new_from_impl(WeightedAvgUdaf::new());
let ctx = SessionContext::new();
ctx.register_udaf(udaf);
// Использование в SQL
let df = ctx.sql(
"SELECT department, weighted_avg(salary, experience) FROM employees GROUP BY department"
).await?;
df.show().await?;
create_udaf: быстрый способ
Для простых случаев без реализации trait:
use datafusion::logical_expr::create_udaf;
let sum_udaf = create_udaf(
"custom_sum",
vec![DataType::Float64], // Типы аргументов
Arc::new(DataType::Float64), // Тип результата
Volatility::Immutable,
Arc::new(|_| Ok(Box::new(SumAccumulator::new()))), // Фабрика аккумулятора
Arc::new(vec![DataType::Float64]), // Типы состояния
);
Функции create_udaf и create_udwf считаются устаревшими начиная с DataFusion 42.x. Для новых проектов используйте AggregateUDF::new_from_impl(impl AggregateUDFImpl) для агрегатных функций и WindowUDF::new_from_impl(impl WindowUDFImpl) для оконных. Convenience-функции сохранены только для обратной совместимости.
UDWF: пользовательские оконные функции
Оконные функции вычисляют значения в контексте окна строк (OVER clause). Rust API использует trait PartitionEvaluator.
Trait PartitionEvaluator
use datafusion::logical_expr::PartitionEvaluator;
use datafusion::arrow::array::{ArrayRef, Float64Array};
use datafusion::common::Result;
use std::sync::Arc;
#[derive(Debug)]
struct ExponentialMovingAvg {
alpha: f64,
}
impl ExponentialMovingAvg {
fn new(alpha: f64) -> Self {
Self { alpha }
}
}
impl PartitionEvaluator for ExponentialMovingAvg {
fn evaluate_all(
&mut self,
values: &[ArrayRef],
num_rows: usize,
) -> Result<ArrayRef> {
let input = values[0].as_any()
.downcast_ref::<Float64Array>()
.unwrap();
let mut ema = Vec::with_capacity(num_rows);
let mut prev = 0.0;
for i in 0..num_rows {
if input.is_null(i) {
ema.push(None);
} else {
let val = input.value(i);
let current = if i == 0 {
val
} else {
self.alpha * val + (1.0 - self.alpha) * prev
};
prev = current;
ema.push(Some(current));
}
}
Ok(Arc::new(Float64Array::from(ema)) as ArrayRef)
}
fn uses_window_frame(&self) -> bool {
false // EMA обрабатывает всю партицию, не зависит от ROWS BETWEEN
}
fn include_rank(&self) -> bool {
false // Не нужна информация о ранге
}
}
evaluate_all вызывается один раз для всей партиции окна. Если ваша функция зависит от оконного фрейма (ROWS BETWEEN), верните true из uses_window_frame() и реализуйте evaluate вместо evaluate_all.
Trait WindowUDFImpl и регистрация
use datafusion::logical_expr::{WindowUDFImpl, Signature, Volatility, WindowUDF};
use datafusion::arrow::datatypes::{DataType, Field};
#[derive(Debug)]
struct EmaUdwf {
signature: Signature,
}
impl EmaUdwf {
fn new() -> Self {
Self {
signature: Signature::exact(
vec![DataType::Float64],
Volatility::Immutable,
),
}
}
}
impl WindowUDFImpl for EmaUdwf {
fn as_any(&self) -> &dyn Any { self }
fn name(&self) -> &str { "ema" }
fn signature(&self) -> &Signature { &self.signature }
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Float64)
}
fn partition_evaluator(
&self,
_partition_evaluator_args: datafusion::logical_expr::PartitionEvaluatorArgs,
) -> Result<Box<dyn PartitionEvaluator>> {
Ok(Box::new(ExponentialMovingAvg::new(0.3)))
}
fn field(&self, field_args: datafusion::logical_expr::WindowUDFFieldArgs)
-> Result<Field>
{
Ok(Field::new(field_args.name(), DataType::Float64, true))
}
}
// Регистрация
let udwf = WindowUDF::new_from_impl(EmaUdwf::new());
let ctx = SessionContext::new();
ctx.register_udwf(udwf);
let df = ctx.sql(
"SELECT ts, price, ema(price) OVER (ORDER BY ts) FROM stock_prices"
).await?;
Сравнение: UDF vs UDAF vs UDWF
Полный пример: медианная агрегация
use datafusion::logical_expr::{Accumulator, AggregateUDFImpl, AggregateUDF, Signature, Volatility};
use datafusion::arrow::array::{ArrayRef, Float64Array};
use datafusion::arrow::datatypes::{DataType, Field};
use datafusion::common::{Result, ScalarValue};
use std::any::Any;
#[derive(Debug)]
struct MedianAccumulator {
values: Vec<f64>,
}
impl Accumulator for MedianAccumulator {
fn state(&self) -> Result<Vec<ScalarValue>> {
// Сериализуем все значения как List
Ok(vec![ScalarValue::List(
ScalarValue::new_list(
&self.values.iter().map(|v| ScalarValue::Float64(Some(*v))).collect::<Vec<_>>(),
&DataType::Float64,
)
)])
}
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let array = values[0].as_any().downcast_ref::<Float64Array>().unwrap();
for i in 0..array.len() {
if !array.is_null(i) {
self.values.push(array.value(i));
}
}
Ok(())
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
// Разворачиваем List-состояния из других партиций
let list_array = &states[0];
// Упрощённо: итерируем по вложенным значениям
let values = list_array.as_any().downcast_ref::<Float64Array>();
if let Some(arr) = values {
for i in 0..arr.len() {
if !arr.is_null(i) {
self.values.push(arr.value(i));
}
}
}
Ok(())
}
fn evaluate(&self) -> Result<ScalarValue> {
if self.values.is_empty() {
return Ok(ScalarValue::Float64(None));
}
let mut sorted = self.values.clone();
sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
let mid = sorted.len() / 2;
let median = if sorted.len() % 2 == 0 {
(sorted[mid - 1] + sorted[mid]) / 2.0
} else {
sorted[mid]
};
Ok(ScalarValue::Float64(Some(median)))
}
fn size(&self) -> usize {
std::mem::size_of::<Self>() + self.values.capacity() * std::mem::size_of::<f64>()
}
}
#[derive(Debug)]
struct MedianUdaf {
signature: Signature,
}
impl AggregateUDFImpl for MedianUdaf {
fn as_any(&self) -> &dyn Any { self }
fn name(&self) -> &str { "median" }
fn signature(&self) -> &Signature { &self.signature }
fn return_type(&self, _: &[DataType]) -> Result<DataType> { Ok(DataType::Float64) }
fn accumulator(&self, _: &datafusion::logical_expr::AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(MedianAccumulator { values: vec![] }))
}
fn state_fields(&self, _: datafusion::logical_expr::StateFieldsArgs) -> Result<Vec<Field>> {
Ok(vec![Field::new("values", DataType::List(Arc::new(Field::new("item", DataType::Float64, true))), true)])
}
}
Медианная агрегация хранит все значения в памяти. Для больших датасетов это может вызвать OOM. В production используйте approximate median (t-digest) или ограничивайте размер через size().
Итоги
-
Accumulatortrait управляет жизненным циклом агрегации:state→update_batch→merge_batch→evaluate -
AggregateUDFImpl+Accumulator— полный контроль;create_udaf— быстрый способ -
state_fields()описывает промежуточное состояние для сериализации между партициями -
PartitionEvaluatorдля UDWF обрабатывает строки в контексте оконного фрейма -
WindowUDFImpl+PartitionEvaluator— оконные функции сevaluate_allилиevaluate -
size()в Accumulator — контроль памяти при параллельных агрегациях с большим числом групп