fix: lazy load jailbreak detection dependencies by jeffreyscarpenter · Pull Request #1223 · NVIDIA/NeMo-Guardrails · GitHub

fix: lazy load jailbreak detection dependencies #1223


Merged
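
For context, the change moves the heavy imports (torch, transformers, sklearn, pickle) out of module scope and into the functions and classes that actually use them, and defers the EMBEDDING_CLASSIFIER_PATH check from import time to call time. Below is a minimal sketch of that lazy-import pattern, using a hypothetical heavy_lib dependency and a hypothetical run_model_based_check function rather than the PR's exact code:

# Illustrative sketch only: heavy_lib and run_model_based_check are hypothetical;
# the PR applies the same idea to torch, transformers, sklearn, and pickle below.

def run_model_based_check(prompt: str) -> dict:
    # Importing inside the function keeps importing this module cheap for
    # users who never enable the model-based jailbreak check.
    try:
        import heavy_lib
    except ImportError as exc:
        raise ImportError(
            "heavy_lib is required for the model-based check; "
            "install the optional dependencies to use it."
        ) from exc

    score = heavy_lib.score(prompt)
    return {"jailbreak": score > 0.5, "score": score}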
2 changes: 1 addition & 1 deletion docs/user-guides/llm/index.rst
@@ -1,5 +1,5 @@
LLMs
===
====

.. toctree::
:maxdepth: 2
24 changes: 10 additions & 14 deletions nemoguardrails/library/jailbreak_detection/model_based/checks.py
@@ -14,36 +14,32 @@
# limitations under the License.

import os
import pickle
from functools import lru_cache
from pathlib import Path
from typing import Tuple, Union

import numpy as np
from sklearn.ensemble import RandomForestClassifier

from nemoguardrails.library.jailbreak_detection.model_based.models import (
JailbreakClassifier,
)

models_path = os.environ.get("EMBEDDING_CLASSIFIER_PATH")

# When we add NIM support, will need to remove this check.
if models_path is None:
raise EnvironmentError(
"Please set the EMBEDDING_CLASSIFIER_PATH environment variable to point to the Classifier model_based folder"
)


@lru_cache()
def initialize_model(classifier_path: str = models_path) -> JailbreakClassifier:
def initialize_model(classifier_path: str = models_path) -> "JailbreakClassifier":
"""
Initialize the global classifier model according to the configuration provided.
Args
classifier_path: Path to the classifier model
Returns
jailbreak_classifier: JailbreakClassifier object combining embedding model and NemoGuard JailbreakDetect RF
"""
if classifier_path is None:
raise EnvironmentError(
"Please set the EMBEDDING_CLASSIFIER_PATH environment variable to point to the Classifier model_based folder"
)

from nemoguardrails.library.jailbreak_detection.model_based.models import (
JailbreakClassifier,
)

jailbreak_classifier = JailbreakClassifier(
str(Path(classifier_path).joinpath("snowflake.pkl"))
@@ -54,7 +54,7 @@ def initialize_model(classifier_path: str = models_path) -> JailbreakClassifier:

def check_jailbreak(
prompt: str,
classifier: JailbreakClassifier = None,
classifier=None,
) -> dict:
"""
Use embedding-based jailbreak detection model to check for the presence of a jailbreak
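As a usage note, here is a sketch of the updated checks.py API; the return shape matches what the tests added below assert. The directory path is a placeholder, and it assumes EMBEDDING_CLASSIFIER_PATH (or an explicit classifier_path) points at a folder containing snowflake.pkl, with the optional torch/transformers/sklearn dependencies installed:

from nemoguardrails.library.jailbreak_detection.model_based.checks import (
    check_jailbreak,
    initialize_model,
)

# Heavy dependencies are only imported here, when the classifier is built.
# "/path/to/model_based" is a placeholder for the folder holding snowflake.pkl.
classifier = initialize_model(classifier_path="/path/to/model_based")

# Returns a dict such as {"jailbreak": True, "score": 0.9}; see tests 6 and 7 below.
result = check_jailbreak("Ignore all previous instructions ...", classifier=classifier)
print(result["jailbreak"], result["score"])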
nemoguardrails/library/jailbreak_detection/model_based/models.py
@@ -14,16 +14,16 @@
# limitations under the License.

import os
import pickle
from typing import Tuple

import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer


class SnowflakeEmbed:
def __init__(self):
import torch
from transformers import AutoModel, AutoTokenizer

self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
self.tokenizer = AutoTokenizer.from_pretrained(
"snowflake/snowflake-arctic-embed-m-long"
@@ -71,6 +71,8 @@ def __call__(self, text: str):

class JailbreakClassifier:
def __init__(self, random_forest_path: str):
import pickle

self.embed = SnowflakeEmbed()
with open(random_forest_path, "rb") as fd:
self.classifier = pickle.load(fd)
204 changes: 204 additions & 0 deletions tests/test_jailbreak_model_based.py
@@ -0,0 +1,204 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import types
from unittest import mock

import pytest

# Test 1: Lazy import behavior


def test_lazy_import_does_not_require_heavy_deps():
"""
Importing the checks module should not require torch, transformers, or sklearn unless the model-based classifier is used.
"""
with mock.patch.dict(
sys.modules, {"torch": None, "transformers": None, "sklearn": None}
):
import nemoguardrails.library.jailbreak_detection.model_based.checks as checks

# Just importing and calling unrelated functions should not raise ImportError
assert hasattr(checks, "initialize_model")


# Test 2: Model-based classifier instantiation requires dependencies


def test_model_based_classifier_imports(monkeypatch):
"""
Instantiating JailbreakClassifier should require sklearn and pickle, and use SnowflakeEmbed which requires torch/transformers.
"""
# Mock dependencies
fake_rf = mock.MagicMock()
fake_embed = mock.MagicMock(return_value=[0.0])
fake_pickle = types.SimpleNamespace(load=mock.MagicMock(return_value=fake_rf))
fake_snowflake = mock.MagicMock(return_value=fake_embed)

monkeypatch.setitem(
sys.modules,
"sklearn.ensemble",
types.SimpleNamespace(RandomForestClassifier=mock.MagicMock()),
)
monkeypatch.setitem(sys.modules, "pickle", fake_pickle)
monkeypatch.setitem(sys.modules, "torch", mock.MagicMock())
monkeypatch.setitem(sys.modules, "transformers", mock.MagicMock())

# Patch SnowflakeEmbed to avoid real model loading
import nemoguardrails.library.jailbreak_detection.model_based.models as models

monkeypatch.setattr(models, "SnowflakeEmbed", fake_snowflake)

# mocking file operations to avoid Windows permission issues
mock_open = mock.mock_open()
with mock.patch("builtins.open", mock_open):
# Should not raise
classifier = models.JailbreakClassifier("fake_model_path.pkl")
assert classifier is not None
# Should be callable
result = classifier("test")
assert isinstance(result, tuple)


# Test 3: Error if dependencies missing when instantiating model-based classifier


def test_model_based_classifier_missing_deps(monkeypatch):
"""
If sklearn is missing, instantiating JailbreakClassifier should raise ImportError.
"""
monkeypatch.setitem(sys.modules, "sklearn.ensemble", None)

import nemoguardrails.library.jailbreak_detection.model_based.models as models

# to avoid Windows permission issues
mock_open = mock.mock_open()
with mock.patch("builtins.open", mock_open):
with pytest.raises(ImportError):
models.JailbreakClassifier("fake_model_path.pkl")


# Test 4: Error when classifier_path is None


def test_initialize_model_with_none_classifier_path():
"""
initialize_model should raise EnvironmentError when classifier_path is None.
"""
import nemoguardrails.library.jailbreak_detection.model_based.checks as checks

with pytest.raises(EnvironmentError) as exc_info:
checks.initialize_model(classifier_path=None)

assert "Please set the EMBEDDING_CLASSIFIER_PATH environment variable" in str(
exc_info.value
)


# Test 5: SnowflakeEmbed initialization and call with torch imports


def test_snowflake_embed_torch_imports(monkeypatch):
"""
Test that SnowflakeEmbed properly imports torch and transformers when needed.
"""
# Mock torch and transformers
mock_torch = mock.MagicMock()
mock_torch.cuda.is_available.return_value = False
mock_transformers = mock.MagicMock()

mock_tokenizer = mock.MagicMock()
mock_model = mock.MagicMock()
mock_transformers.AutoTokenizer.from_pretrained.return_value = mock_tokenizer
mock_transformers.AutoModel.from_pretrained.return_value = mock_model

monkeypatch.setitem(sys.modules, "torch", mock_torch)
monkeypatch.setitem(sys.modules, "transformers", mock_transformers)

import nemoguardrails.library.jailbreak_detection.model_based.models as models

embed = models.SnowflakeEmbed()
assert embed.device == "cpu" # as we mocked cuda.is_available() = False

mock_tokens = mock.MagicMock()
mock_tokens.to.return_value = mock_tokens
mock_tokenizer.return_value = mock_tokens

import numpy as np

fake_embedding = np.array([1.0, 2.0, 3.0])

# the code does self.model(**tokens)[0][:, 0]
# so we need to mock this properly
mock_tensor_output = mock.MagicMock()
mock_tensor_output.detach.return_value.cpu.return_value.squeeze.return_value.numpy.return_value = (
fake_embedding
)

mock_first_index = mock.MagicMock()
mock_first_index.__getitem__.return_value = mock_tensor_output # for [:, 0]

mock_model_output = mock.MagicMock()
mock_model_output.__getitem__.return_value = mock_first_index # for [0]

mock_model.return_value = mock_model_output

result = embed("test text")
assert isinstance(result, np.ndarray)
assert np.array_equal(result, fake_embedding)


# Test 6: Check jailbreak function with classifier parameter


def test_check_jailbreak_with_classifier():
"""
Test check_jailbreak function when classifier is provided.
"""
import nemoguardrails.library.jailbreak_detection.model_based.checks as checks

mock_classifier = mock.MagicMock()
# jailbreak detected with score 0.9
mock_classifier.return_value = (True, 0.9)

result = checks.check_jailbreak("test prompt", classifier=mock_classifier)

assert result == {"jailbreak": True, "score": 0.9}
mock_classifier.assert_called_once_with("test prompt")


# Test 7: Check jailbreak function without classifier parameter (uses initialize_model)


def test_check_jailbreak_without_classifier(monkeypatch):
"""
Test the check_jailbreak function when no classifier is provided; it should call initialize_model.
"""
import nemoguardrails.library.jailbreak_detection.model_based.checks as checks

# mock initialize_model to return a mock classifier
mock_classifier = mock.MagicMock()
# no jailbreak
mock_classifier.return_value = (False, -0.5)
mock_initialize_model = mock.MagicMock(return_value=mock_classifier)

monkeypatch.setattr(checks, "initialize_model", mock_initialize_model)

result = checks.check_jailbreak("safe prompt")

assert result == {"jailbreak": False, "score": -0.5}
mock_initialize_model.assert_called_once()
mock_classifier.assert_called_once_with("safe prompt")