Use decorator for runtime patching by ismailsimsek · Pull Request #59 · memiiso/opendbt · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Use decorator for runtime patching #59

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/tests-dbt-version.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ jobs:
# FIX for protobuf issue: https://github.com/dbt-labs/dbt-core/issues/9759
pip install -q "apache-airflow" "protobuf>=4.25.3,<5.0.0" "opentelemetry-proto<1.28.0" --prefer-binary
pip install -q .[test] --prefer-binary
pip install -q dbt-core==${{ inputs.dbt-version }}.* dbt-duckdb==${{ inputs.dbt-version }}.* --force-reinstall --upgrade
python --version
python -c "from dbt.version import get_installed_version as get_dbt_version;print(f'dbt version={get_dbt_version()}')"
python -m compileall -f opendbt
Expand Down
17 changes: 3 additions & 14 deletions opendbt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,12 @@
import logging
import os
import sys
from pathlib import Path

######################
from opendbt.dbt import patch_dbt

patch_dbt()
# IMPORTANT! this will import the overrides, and activates the patches
# IMPORTANT! `opendbt.dbt` import needs to happen before any `dbt` import
from opendbt.dbt import *
from opendbt.logger import OpenDbtLogger
from opendbt.utils import Utils
######################

from dbt.cli.main import dbtRunner as DbtCliRunner
from dbt.cli.main import dbtRunnerResult
from dbt.config import PartialProject
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.results import RunResult
from dbt.exceptions import DbtRuntimeError
from dbt.task.base import get_nearest_project_dir

class OpenDbtCli:

Expand Down
1 change: 0 additions & 1 deletion opendbt/airflow/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ def manifest(self):
# static_url_path='/dbtdocsview'
)


class AirflowDbtDocsPlugin(AirflowPlugin):
name = "DBT Docs Plugin"
flask_blueprints = [bp]
Expand Down
40 changes: 21 additions & 19 deletions opendbt/dbt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,39 @@
import dbt
from dbt import version
from packaging.version import Version

import opendbt.dbt.shared.cli.main
from opendbt.runtime_patcher import RuntimePatcher

def patch_dbt():
# ================================================================================================================
# Monkey Patching! Override dbt lib code with new one
# ================================================================================================================
try:
dbt_version = Version(version.get_installed_version().to_version_string(skip_matcher=True))
if Version("1.6.0") <= dbt_version < Version("1.8.0"):
from opendbt.dbt.v17.adapters.factory import OpenDbtAdapterContainer
from opendbt.dbt.v17.config.runtime import OpenDbtRuntimeConfig
dbt.config.RuntimeConfig = OpenDbtRuntimeConfig
from opendbt.dbt.v17.task.docs.generate import OpenDbtGenerateTask
dbt.task.generate.GenerateTask = OpenDbtGenerateTask
from opendbt.dbt.v17.adapters.factory import OpenDbtAdapterContainer
dbt.adapters.factory.FACTORY = OpenDbtAdapterContainer()
from opendbt.dbt.v17.task.run import ModelRunner
dbt.task.run.ModelRunner = ModelRunner
elif Version("1.8.0") <= dbt_version < Version("1.10.0"):
from opendbt.dbt.v18.adapters.factory import OpenDbtAdapterContainer
from opendbt.dbt.v18.config.runtime import OpenDbtRuntimeConfig
dbt.config.RuntimeConfig = OpenDbtRuntimeConfig
from opendbt.dbt.v18.task.docs.generate import OpenDbtGenerateTask
dbt.task.docs.generate.GenerateTask = OpenDbtGenerateTask
from opendbt.dbt.v18.adapters.factory import OpenDbtAdapterContainer
dbt.adapters.factory.FACTORY = OpenDbtAdapterContainer()
from opendbt.dbt.v18.task.run import ModelRunner
dbt.task.run.ModelRunner = ModelRunner
else:
raise Exception(
f"Unsupported dbt version {dbt_version}, please make sure dbt version is supported/integrated by opendbt")

RuntimePatcher(module_name="dbt.adapters.factory").patch_attribute(attribute_name="FACTORY",
new_value=OpenDbtAdapterContainer())
# shared code patches
import opendbt.dbt.shared.cli.main
dbt.cli.main.sqlfluff = opendbt.dbt.shared.cli.main.sqlfluff
dbt.cli.main.sqlfluff_lint = opendbt.dbt.shared.cli.main.sqlfluff_lint
dbt.cli.main.sqlfluff_fix = opendbt.dbt.shared.cli.main.sqlfluff_fix
from opendbt.dbt.shared.cli.main import sqlfluff
from opendbt.dbt.shared.cli.main import sqlfluff_lint
from opendbt.dbt.shared.cli.main import sqlfluff_fix

# dbt imports
from dbt.cli.main import dbtRunner as DbtCliRunner
from dbt.cli.main import dbtRunnerResult
from dbt.config import PartialProject
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.results import RunResult
from dbt.exceptions import DbtRuntimeError
from dbt.task.base import get_nearest_project_dir
except:
raise
4 changes: 4 additions & 0 deletions opendbt/dbt/shared/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
from dbt.cli.main import global_flags, cli

from opendbt.dbt.shared.task.sqlfluff import SqlFluffTasks
from opendbt.runtime_patcher import PatchFunction


# dbt docs
@cli.group()
@click.pass_context
@global_flags
@PatchFunction(module_name="dbt.cli.main", target_name="sqlfluff")
def sqlfluff(ctx, **kwargs):
"""Generate or serve the documentation website for your project"""

Expand Down Expand Up @@ -45,6 +47,7 @@ def sqlfluff(ctx, **kwargs):
@requires.project
@requires.runtime_config
@requires.manifest(write=False)
@PatchFunction(module_name="dbt.cli.main", target_name="sqlfluff_lint")
def sqlfluff_lint(ctx, **kwargs):
"""Generate the documentation website for your project"""
task = SqlFluffTasks(
Expand Down Expand Up @@ -90,6 +93,7 @@ def sqlfluff_lint(ctx, **kwargs):
@requires.project
@requires.runtime_config
@requires.manifest(write=False)
@PatchFunction(module_name="dbt.cli.main", target_name="sqlfluff_lint")
def sqlfluff_fix(ctx, **kwargs):
"""Generate the documentation website for your project"""
task = SqlFluffTasks(
Expand Down
4 changes: 4 additions & 0 deletions opendbt/dbt/v17/adapters/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@
from dbt.events.types import AdapterRegistered
from dbt.semver import VersionSpecifier

from opendbt.runtime_patcher import PatchClass


@PatchClass(module_name="dbt.adapters.factory", target_name="AdapterContainer")
class OpenDbtAdapterContainer(factory.AdapterContainer):
DBT_CUSTOM_ADAPTER_VAR = 'dbt_custom_adapter'

def register_adapter(self, config: 'AdapterRequiredConfig') -> None:
# ==== CUSTOM CODE ====
# ==== END CUSTOM CODE ====
Expand Down
7 changes: 4 additions & 3 deletions opendbt/dbt/v17/config/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
from dbt.config import RuntimeConfig
from dbt.config.project import path_exists, _load_yaml
from dbt.constants import DEPENDENCIES_FILE_NAME
from dbt.exceptions import (
DbtProjectError, NonUniquePackageNameError,
)
from dbt.exceptions import DbtProjectError, NonUniquePackageNameError
from typing_extensions import override

from opendbt.runtime_patcher import PatchClass

def load_yml_dict(file_path):
ret = {}
Expand All @@ -19,6 +18,8 @@ def load_yml_dict(file_path):

# pylint: disable=too-many-ancestors
@dataclass
@PatchClass(module_name="dbt.config", target_name="RuntimeConfig")
@PatchClass(module_name="dbt.cli.requires", target_name="RuntimeConfig")
class OpenDbtRuntimeConfig(RuntimeConfig):
def load_dependence_projects(self):
dependencies_yml_dict = load_yml_dict(f"{self.project_root}/{DEPENDENCIES_FILE_NAME}")
Expand Down
3 changes: 3 additions & 0 deletions opendbt/dbt/v17/task/docs/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
import click
from dbt.task.generate import GenerateTask

from opendbt.runtime_patcher import PatchClass


@PatchClass(module_name="dbt.task.generate", target_name="GenerateTask")
class OpenDbtGenerateTask(GenerateTask):

def deploy_user_index_html(self):
Expand Down
3 changes: 3 additions & 0 deletions opendbt/dbt/v17/task/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
)
from dbt.task import run

from opendbt.runtime_patcher import PatchClass


@PatchClass(module_name="dbt.task.run", target_name="ModelRunner")
class ModelRunner(run.ModelRunner):

def print_result_adapter_response(self, result):
Expand Down
4 changes: 4 additions & 0 deletions opendbt/dbt/v18/adapters/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,13 @@
from dbt_common.events.base_types import EventLevel
from dbt_common.events.functions import fire_event

from opendbt.runtime_patcher import PatchClass


@PatchClass(module_name="dbt.adapters.factory", target_name="AdapterContainer")
class OpenDbtAdapterContainer(factory.AdapterContainer):
DBT_CUSTOM_ADAPTER_VAR = 'dbt_custom_adapter'

def register_adapter(
self,
config: 'AdapterRequiredConfig',
Expand Down
7 changes: 4 additions & 3 deletions opendbt/dbt/v18/config/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
from dbt.config import RuntimeConfig
from dbt.config.project import load_yml_dict
from dbt.constants import DEPENDENCIES_FILE_NAME
from dbt.exceptions import (
DbtProjectError, NonUniquePackageNameError,
)
from dbt.exceptions import DbtProjectError, NonUniquePackageNameError
from typing_extensions import override

from opendbt.runtime_patcher import PatchClass

# pylint: disable=too-many-ancestors
@dataclass
@PatchClass(module_name="dbt.config", target_name="RuntimeConfig")
@PatchClass(module_name="dbt.cli.requires", target_name="RuntimeConfig")
class OpenDbtRuntimeConfig(RuntimeConfig):
def load_dependence_projects(self):
dependencies_yml_dict = load_yml_dict(f"{self.project_root}/{DEPENDENCIES_FILE_NAME}")
Expand Down
3 changes: 3 additions & 0 deletions opendbt/dbt/v18/task/docs/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
import click
from dbt.task.docs.generate import GenerateTask

from opendbt.runtime_patcher import PatchClass


@PatchClass(module_name="dbt.task.docs.generate", target_name="GenerateTask")
class OpenDbtGenerateTask(GenerateTask):

def deploy_user_index_html(self):
Expand Down
3 changes: 3 additions & 0 deletions opendbt/dbt/v18/task/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
from dbt_common.events.base_types import EventLevel
from dbt_common.events.functions import fire_event

from opendbt.runtime_patcher import PatchClass


@PatchClass(module_name="dbt.task.run", target_name="ModelRunner")
class ModelRunner(run.ModelRunner):

def print_result_adapter_response(self, result):
Expand Down
19 changes: 8 additions & 11 deletions opendbt/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,11 @@ def email_dbt_test_callback(event: "EventMsg"):
--------------- full data ---------------
dbt data: {event.data}
"""
try:
# send email alert using airflow
from airflow.utils.email import send_email
send_email(
subject=email_subject,
to="my-slack-notification-channel@slack.com",
html_content=email_html_content
)
except Exception as _:
# Used by unittest, expecting airflow error
logging.getLogger('dbtcallbacks').error("Airflow send_email failed! this is expected for unit testing!")
# @TODO send email alert using airflow
# from airflow.utils.email import send_email
# send_email(
# subject=email_subject,
# to="my-slack-notification-channel@slack.com",
# html_content=email_html_content
# )
logging.getLogger('dbtcallbacks').error("Callback email sent!")
Loading
0