Custom hooks — BaseHook subclass, connection lookup
Hook в Airflow — это typed client для external service, использующий Connection из metadata DB как источник credentials. Hook abstraction позволяет писать operators, которые не знают где конкретно живут creds (Postgres connections, K8s secrets, AWS Secrets Manager — все через единый API). Этот урок показывает анатомию BaseHook, паттерн get_conn(), full example custom hook для proprietary REST API, и testing через mocking.
Зачем нужен Hook (vs прямое использование клиента)
Допустим, вы написали task, которая делает HTTP requests к internal API:
# Naive version kept for contrast: the URL is hard-coded and the bearer token
# is read straight from the process environment instead of an Airflow Connection.
@task
def fetch_data():
    import os

    import requests

    response = requests.get(
        "https://internal-api.corp/data",
        headers={"Authorization": f"Bearer {os.environ['API_TOKEN']}"},
        timeout=30,
    )
    return response.json()
Проблемы:
- Token из env var — hardcoded path, не из Airflow secrets backend
- URL hardcoded — staging/prod variants нужно сами обрабатывать
- Нет переиспользования между tasks
- Не тестируется без mocking всего requests
- В UI Connections нет регистрации — DevOps team не знает, что task использует token
Hook решает всё это:
# Hook-based version: conn_id names the Airflow Connection the hook reads from.
@task
def fetch_data():
    api = InternalApiHook(conn_id="internal_api_prod")
    payload = api.get_data()
    return payload
Connection internal_api_prod хранится в metadata DB (или Secrets Manager) с URL, token, options. UI показывает её в /connection/list. Hook абстрагирует HTTP details.
BaseHook — что внутри
BaseHook — minimum abstraction. Real implementation:
# airflow/hooks/base.py — упрощённо
class BaseHook(LoggingMixin):
    """Abstract base class for all hooks."""

    # Name of the attribute that stores the connection id on a hook instance.
    conn_name_attr: str = "conn_id"

    def __init__(self, *, conn_id: "str | None" = None):
        """Store the connection id; nothing is looked up here."""
        super().__init__()
        self.conn_id = conn_id

    @classmethod
    def get_connection(cls, conn_id: str) -> "Connection":
        """Get connection from secrets backends, env, or DB."""
        # Connection is imported inside the method, so the return annotation
        # above is quoted to avoid a NameError when the class body is executed.
        from airflow.models.connection import Connection

        return Connection.get_connection_from_secrets(conn_id)

    @classmethod
    def get_hook(cls, conn_id: str) -> "BaseHook":
        """Factory method based on connection type."""
        connection = cls.get_connection(conn_id)
        return connection.get_hook()

    def get_conn(self) -> "Any":
        """Return the connection client (raw).

        Raises:
            NotImplementedError: always; subclasses are expected to override.
        """
        raise NotImplementedError
Ключевые методы:
- get_connection(conn_id) — class method, выполняет lookup через chain: secrets backend → env var (AIRFLOW_CONN_<UPPER>) → metadata DB.
- get_conn() — instance method, возвращает raw client (например, psycopg2.Connection или boto3.Client). Каждый hook implements его.
- get_hook() — factory: получить hook без явного импорта класса (через connection type).
Lookup chain — где Airflow ищет connection
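При BaseHook.get_connection("my_conn") Airflow по порядку проверяет: 1) secrets backend (если настроен); 2) env var AIRFLOW_CONN_MY_CONN; 3) metadata DB. Возвращается первое найденное совпадение.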
Chain priority — критично для production: Secrets Manager > env vars > DB. Если та же conn_id есть в Vault и DB — Vault выигрывает.
with-statement — __enter__ и __exit__
Full example: custom hook для proprietary REST API
Допустим, у вас internal API “DataCorp” с REST endpoints, bearer auth, retry-able. Создаём hook.
# my_org_provider/hooks/datacorp.py
from __future__ import annotations
from typing import Any
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
from airflow.hooks.base import BaseHook
from airflow.exceptions import AirflowException
class DataCorpHook(BaseHook):
    """Hook for DataCorp internal API.

    Builds an authenticated ``requests.Session`` from an Airflow Connection and
    exposes thin helpers over the ``/resources`` endpoints.

    Connection fields read by this hook:
        - host: API base URL
        - password: bearer token
    """

    conn_name_attr = "datacorp_conn_id"
    default_conn_name = "datacorp_default"
    conn_type = "datacorp"
    hook_name = "DataCorp"

    def __init__(self, datacorp_conn_id: str = default_conn_name, timeout: int = 30):
        """Remember the connection id and per-request timeout.

        Args:
            datacorp_conn_id: id of the Airflow Connection to use.
            timeout: value passed as ``timeout`` to every HTTP call made by this hook.
        """
        super().__init__()
        self.datacorp_conn_id = datacorp_conn_id
        self.timeout = timeout
        # Both are filled in by the first get_conn() call.
        self._session: requests.Session | None = None
        self._base_url: str | None = None

    def get_conn(self) -> requests.Session:
        """Return authenticated requests.Session.

        The session is created on first use and reused on later calls.

        Raises:
            AirflowException: if the Connection has no host configured.
        """
        if self._session is not None:
            return self._session

        connection = self.get_connection(self.datacorp_conn_id)
        # Fail with a readable message instead of an AttributeError on None.rstrip().
        if not connection.host:
            raise AirflowException(
                f"Connection {self.datacorp_conn_id!r} has no host (API base URL) configured"
            )

        session = requests.Session()

        # Authorization
        session.headers.update({
            "Authorization": f"Bearer {connection.password}",
            "User-Agent": "Airflow-DataCorpHook/1.0",
        })

        # Retry policy
        # NOTE(review): POST is in allowed_methods, so a failed create may be
        # re-sent; confirm the API tolerates duplicate POSTs.
        retry = Retry(
            total=3,
            backoff_factor=0.5,
            status_forcelist=[429, 500, 502, 503, 504],
            allowed_methods=["GET", "POST", "PUT"],
        )
        # Only https:// URLs get the retrying adapter.
        session.mount("https://", HTTPAdapter(max_retries=retry))

        self._base_url = connection.host.rstrip("/")
        self._session = session
        return session

    def get_resource(self, resource_id: str) -> dict[str, Any]:
        """GET /resources/{id}.

        Raises:
            AirflowException: if the API answers 404 for ``resource_id``.
        """
        session = self.get_conn()
        url = f"{self._base_url}/resources/{resource_id}"
        response = session.get(url, timeout=self.timeout)
        # 404 is checked before raise_for_status() so it surfaces as an AirflowException.
        if response.status_code == 404:
            raise AirflowException(f"Resource {resource_id} not found")
        response.raise_for_status()
        return response.json()

    def list_resources(self, page_size: int = 100) -> list[dict]:
        """Paginated list via /resources.

        Follows ``next_cursor`` from each page until it is missing or empty and
        returns the concatenated ``items`` of all pages.
        """
        session = self.get_conn()
        url = f"{self._base_url}/resources"
        params = {"page_size": page_size, "cursor": None}
        results = []
        while True:
            response = session.get(url, params=params, timeout=self.timeout)
            response.raise_for_status()
            data = response.json()
            results.extend(data["items"])
            if not data.get("next_cursor"):
                break
            params["cursor"] = data["next_cursor"]
        return results

    def create_resource(self, payload: dict) -> dict:
        """POST /resources."""
        session = self.get_conn()
        response = session.post(
            f"{self._base_url}/resources",
            json=payload,
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response.json()

    @staticmethod
    def get_ui_field_behaviour() -> dict[str, Any]:
        """Customize UI Connection form (Airflow 2.0+)."""
        return {
            "hidden_fields": ["schema", "port", "login"],
            "relabeling": {
                "host": "API Base URL",
                "password": "Bearer Token",
            },
            "placeholders": {
                "host": "https://api.datacorp.example.com",
                "password": "Your bearer token",
            },
        }

    @staticmethod
    def get_connection_form_widgets() -> dict[str, Any]:
        """Add extra fields to UI Connection form (2.0+)."""
        # Imported here rather than at module level, as in the original.
        from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
        from flask_babel import lazy_gettext
        from wtforms import StringField

        return {
            "api_version": StringField(
                lazy_gettext("API Version"),
                widget=BS3TextFieldWidget(),
                default="v1",
            ),
        }
Key patterns:
- conn_type = "datacorp" — это значение появится в UI Connection type dropdown (если зарегистрирован в provider).
- get_conn() lazily создаёт session и cache-ит — несколько вызовов hook методов используют одну session.
- Retry policy на session level — все GET/POST/PUT requests automatically retry на 429/5xx.
- get_ui_field_behaviour — кастомизация UI Connection form (hide unused fields, relabel).
- get_connection_form_widgets — добавление custom fields (например, API version).
Registration
File-based plugin
# $AIRFLOW_HOME/plugins/datacorp_plugin.py
from airflow.plugins_manager import AirflowPlugin
from my_org_provider.hooks.datacorp import DataCorpHook
class DataCorpPlugin(AirflowPlugin):
    """Plugin declaration that lists ``DataCorpHook`` under ``hooks``."""

    name = "datacorp_plugin"
    hooks = [DataCorpHook]
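После restart plugin загружается. Важно: начиная с Airflow 2.0 импорт вида from airflow.hooks.datacorp import DataCorpHook больше не поддерживается — hook импортируется как обычный Python-модуль: from my_org_provider.hooks.datacorp import DataCorpHook.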
Provider package
# my_org_provider/__init__.py
def get_provider_info():
    """Return the provider metadata dict.

    The dict names the package, its versions, and the hook class, both in the
    flat ``hook-class-names`` list and as a ``connection-types`` entry for the
    ``datacorp`` connection type. A new dict is built on every call.
    """
    # Single source for the dotted path so the two entries cannot drift apart.
    hook_class = "my_org_provider.hooks.datacorp.DataCorpHook"
    return {
        "package-name": "my-org-airflow-provider",
        "name": "MyOrg Provider",
        "versions": ["1.0.0"],
        "hook-class-names": [hook_class],
        "connection-types": [
            {
                "connection-type": "datacorp",
                "hook-class-name": hook_class,
            }
        ],
    }
После pip install — Hook автоматически доступен, UI показывает “DataCorp” в Connection type dropdown.
Использование в DAG
from airflow.decorators import dag, task
from datetime import datetime
from my_org_provider.hooks.datacorp import DataCorpHook
# Daily pipeline: list resources, merge each with its detail record, then save.
@dag(schedule="@daily", start_date=datetime(2026, 1, 1), catchup=False)
def datacorp_etl():
    @task
    def fetch_resources() -> list[dict]:
        # One hook per task; it reads the "datacorp_prod" connection.
        return DataCorpHook(datacorp_conn_id="datacorp_prod").list_resources()

    @task
    def enrich(resources: list[dict]) -> list[dict]:
        client = DataCorpHook(datacorp_conn_id="datacorp_prod")
        # Keys from the detail record override the listing's keys on conflict.
        return [
            {**item, **client.get_resource(item["id"])}
            for item in resources
        ]

    @task
    def save(enriched: list[dict]):
        # write to warehouse
        pass

    listed = fetch_resources()
    merged = enrich(listed)
    save(merged)


datacorp_etl()
Чисто, типизированно, secrets из Connection.
Testing custom hook
Тесты через unittest.mock и pytest:
# tests/hooks/test_datacorp.py
from unittest.mock import patch, MagicMock
import pytest
from my_org_provider.hooks.datacorp import DataCorpHook
@pytest.fixture
def mock_connection():
    """Stand-in for the Airflow Connection lookup result (test host and token)."""
    return MagicMock(host="https://api.test", password="test_token")
def test_get_resource(mock_connection):
    """get_resource returns the decoded JSON body and issues exactly one GET."""
    # Patch the connection lookup and Session.get so the test needs
    # neither the Airflow metadata DB nor real HTTP.
    with patch.object(DataCorpHook, "get_connection", return_value=mock_connection):
        with patch("requests.Session.get") as mock_get:
            mock_response = MagicMock()
            mock_response.status_code = 200
            mock_response.json.return_value = {"id": "r1", "name": "Test"}
            mock_get.return_value = mock_response

            hook = DataCorpHook(datacorp_conn_id="test_conn")
            result = hook.get_resource("r1")

            assert result == {"id": "r1", "name": "Test"}
            mock_get.assert_called_once_with(
                "https://api.test/resources/r1",
                timeout=30,
            )
def test_resource_not_found(mock_connection):
    """A 404 response from the API surfaces as AirflowException('... not found')."""
    from airflow.exceptions import AirflowException

    not_found = MagicMock()
    not_found.status_code = 404

    with patch.object(DataCorpHook, "get_connection", return_value=mock_connection):
        with patch("requests.Session.get") as mock_get:
            mock_get.return_value = not_found
            hook = DataCorpHook(datacorp_conn_id="test_conn")
            with pytest.raises(AirflowException, match="not found"):
                hook.get_resource("r1")
Pattern:
- Mock get_connection, чтобы не depend на real Airflow metadata DB.
- Mock requests.Session.get/post для no real HTTP.
- Verify call args + return value.
Connection definition
В UI / CLI создать connection:
airflow connections add datacorp_prod \
--conn-type datacorp \
--conn-host "https://api.datacorp.example.com" \
--conn-password "$DATACORP_TOKEN" \
--conn-extra '{"api_version": "v1"}'
Или через env var:
export AIRFLOW_CONN_DATACORP_PROD="datacorp://:$DATACORP_TOKEN@api.datacorp.example.com?api_version=v1"
URI format: <conn_type>://<login>:<password>@<host>:<port>/<schema>?<extra params>.
Production gotchas
1. get_conn() cache в self._session — single-task-instance lifetime. Между task runs hook re-instantiated. Не пытайтесь cache между tasks через class-level attributes — это broken in distributed setup (multiple workers).
2. Connection lookup происходит при get_connection(), не при __init__. Если в task у вас:
hook = DataCorpHook(datacorp_conn_id="...")  # <- no Connection lookup happens here
а потом hook не используется — Connection не запрошена. Это хорошо для performance.
3. Heavy imports в hook module loaded для каждого DAG parse. Если hook импортирует pandas, scipy — это раздувает DAG parsing time. Move heavy imports внутрь methods (lazy).
4. Secrets backend caching — [secrets] use_cache=True (2.7+). Лучше включить — без cache каждый get_connection делает full backend call. Cache TTL cache_ttl_seconds default 900.
5. Custom hook не работает в DAG, если plugin не loaded. Worker, scheduler — все процессы должны иметь plugin available. В Kubernetes — image должен включать provider package.