diff --git a/nemoguardrails/library/injection_detection/actions.py b/nemoguardrails/library/injection_detection/actions.py
index 947b55d37..7a85e2993 100644
--- a/nemoguardrails/library/injection_detection/actions.py
+++ b/nemoguardrails/library/injection_detection/actions.py
@@ -32,7 +32,7 @@
import re
from functools import lru_cache
from pathlib import Path
-from typing import Dict, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, TypedDict, Union
yara = None
try:
@@ -49,6 +49,12 @@
log = logging.getLogger(__name__)
+class InjectionDetectionResult(TypedDict):
+ is_injection: bool
+ text: str
+ detections: List[str]
+
+
def _check_yara_available():
if yara is None:
raise ImportError(
@@ -197,13 +203,13 @@ def _load_rules(
}
rules = yara.compile(filepaths=rules_to_load)
except yara.SyntaxError as e:
- msg = f"Encountered SyntaxError: {e}"
+ msg = f"Failed to initialize injection detection due to configuration or YARA rule error: YARA compilation failed: {e}"
log.error(msg)
- raise e
+ return None
return rules
-def _omit_injection(text: str, matches: list["yara.Match"]) -> str:
+def _omit_injection(text: str, matches: list["yara.Match"]) -> Tuple[bool, str]:
"""
Attempts to strip the offending injection attempts from the provided text.
@@ -216,14 +222,18 @@ def _omit_injection(text: str, matches: list["yara.Match"]) -> str:
matches (list['yara.Match']): A list of YARA rule matches.
Returns:
- str: The text with the detected injections stripped out.
+ Tuple[bool, str]: A tuple containing:
+ - bool: True if injection was detected and modified,
+ False if the text is safe (i.e., not modified).
+ - str: The text, with detected injections stripped out if modified.
Raises:
ImportError: If the yara module is not installed.
"""
- # Copy the text to a placeholder variable
+ original_text = text
modified_text = text
+ is_injection = False
for match in matches:
if match.strings:
for match_string in match.strings:
@@ -234,10 +244,16 @@ def _omit_injection(text: str, matches: list["yara.Match"]) -> str:
modified_text = modified_text.replace(plaintext, "")
except (AttributeError, UnicodeDecodeError) as e:
log.warning(f"Error processing match: {e}")
- return modified_text
+
+ if modified_text != original_text:
+ is_injection = True
+ return is_injection, modified_text
+ else:
+ is_injection = False
+ return is_injection, original_text
-def _sanitize_injection(text: str, matches: list["yara.Match"]) -> str:
+def _sanitize_injection(text: str, matches: list["yara.Match"]) -> Tuple[bool, str]:
"""
Attempts to sanitize the offending injection attempts in the provided text.
This is done by 'de-fanging' the offending content, transforming it into a state that will not execute
@@ -253,19 +269,27 @@ def _sanitize_injection(text: str, matches: list["yara.Match"]) -> str:
matches (list['yara.Match']): A list of YARA rule matches.
Returns:
- str: The text with the detected injections sanitized.
+ Tuple[bool, str]: A tuple containing:
+ - bool: True if injection was detected, False otherwise.
+ - str: The sanitized text, or original text depending on sanitization outcome.
+ Currently, this function will always raise NotImplementedError.
Raises:
NotImplementedError: If the sanitization logic is not implemented.
ImportError: If the yara module is not installed.
"""
-
raise NotImplementedError(
"Injection sanitization is not yet implemented. Please use 'reject' or 'omit'"
)
+ # NOTE: unreachable sketch (the raise above always fires) of a possible future implementation returning the same (bool, str) tuple that injection_detection unpacks:
+ # sanitized_text_attempt = "..." # result of sanitization
+ # if sanitized_text_attempt != text:
+ # return True, text # Original text returned, marked as injection detected
+ # else:
+ # return False, sanitized_text_attempt
-def _reject_injection(text: str, rules: "yara.Rules") -> Tuple[bool, str]:
+def _reject_injection(text: str, rules: "yara.Rules") -> Tuple[bool, List[str]]:
"""
Detects whether the provided text contains potential injection attempts.
@@ -277,8 +301,9 @@ def _reject_injection(text: str, rules: "yara.Rules") -> Tuple[bool, str]:
rules ('yara.Rules'): The loaded YARA rules.
Returns:
- bool: True if attempted exploitation is detected, False otherwise.
- str: list of matches as a string
+ Tuple[bool, List[str]]: A tuple containing:
+ - bool: True if attempted exploitation is detected, False otherwise.
+ - List[str]: List of matched rule names.
Raises:
ValueError: If the `action` parameter in the configuration is invalid.
@@ -289,18 +314,20 @@ def _reject_injection(text: str, rules: "yara.Rules") -> Tuple[bool, str]:
log.warning(
"reject_injection guardrail was invoked but no rules were specified in the InjectionDetection config."
)
- return False, ""
+ return False, []
matches = rules.match(data=text)
if matches:
- matches_string = ", ".join([match_name.rule for match_name in matches])
- log.info(f"Input matched on rule {matches_string}.")
- return True, matches_string
+ matched_rules = [match_name.rule for match_name in matches]
+ log.info(f"Input matched on rule {', '.join(matched_rules)}.")
+ return True, matched_rules
else:
- return False, ""
+ return False, []
@action()
-async def injection_detection(text: str, config: RailsConfig) -> str:
+async def injection_detection(
+ text: str, config: RailsConfig
+) -> InjectionDetectionResult:
"""
Detects and mitigates potential injection attempts in the provided text.
@@ -310,45 +337,68 @@ async def injection_detection(text: str, config: RailsConfig) -> str:
Args:
text (str): The text to check for command injection.
+
config (RailsConfig): The Rails configuration object containing injection detection settings.
Returns:
- str: The sanitized or original text, depending on the action specified in the configuration.
+ InjectionDetectionResult: A TypedDict containing:
+ - is_injection (bool): Whether an injection was detected. For the 'reject' action this is True when any rule matches;
+ for the 'omit' action it is True only if matched content was actually removed from the text.
+ - text (str): The sanitized or original text
+ - detections (List[str]): List of matched rule names if any injection is detected
Raises:
ValueError: If the `action` parameter in the configuration is invalid.
NotImplementedError: If an unsupported action is encountered.
+ ImportError: If the yara module is not installed.
"""
_check_yara_available()
_validate_injection_config(config)
+
action_option, yara_path, rule_names, yara_rules = _extract_injection_config(config)
rules = _load_rules(yara_path, rule_names, yara_rules)
- if action_option == "reject":
- verdict, detections = _reject_injection(text, rules)
- if verdict:
- return f"I'm sorry, the desired output triggered rule(s) designed to mitigate exploitation of {detections}."
- else:
- return text
if rules is None:
log.warning(
"injection detection guardrail was invoked but no rules were specified in the InjectionDetection config."
)
- return text
- matches = rules.match(data=text)
- if matches:
- matches_string = ", ".join([match_name.rule for match_name in matches])
- log.info(f"Input matched on rule {matches_string}.")
- if action_option == "omit":
- return _omit_injection(text, matches)
- elif action_option == "sanitize":
- return _sanitize_injection(text, matches)
+ return InjectionDetectionResult(is_injection=False, text=text, detections=[])
+
+ if action_option == "reject":
+ is_injection, detected_rules = _reject_injection(text, rules)
+ return InjectionDetectionResult(
+ is_injection=is_injection, text=text, detections=detected_rules
+ )
+ else:
+ matches = rules.match(data=text)
+ if matches:
+ detected_rules_list = [match_name.rule for match_name in matches]
+ log.info(f"Input matched on rule {', '.join(detected_rules_list)}.")
+
+ if action_option == "omit":
+ is_injection, result_text = _omit_injection(text, matches)
+ return InjectionDetectionResult(
+ is_injection=is_injection,
+ text=result_text,
+ detections=detected_rules_list,
+ )
+ elif action_option == "sanitize":
+ # _sanitize_injection will raise NotImplementedError before returning a tuple.
+ # the assignment below is for structural consistency if it were implemented.
+ is_injection, result_text = _sanitize_injection(text, matches)
+ return InjectionDetectionResult(
+ is_injection=is_injection,
+ text=result_text,
+ detections=detected_rules_list,
+ )
+ else:
+ raise NotImplementedError(
+ f"Expected `action` parameter to be 'reject', 'omit', or 'sanitize' but got {action_option} instead."
+ )
+ # no matches found
else:
- # We should never ever hit this since we inspect the action option above, but putting an error here anyway.
- raise NotImplementedError(
- f"Expected `action` parameter to be 'omit' or 'sanitize' but got {action_option} instead."
+ return InjectionDetectionResult(
+ is_injection=False, text=text, detections=[]
)
- else:
- return text
diff --git a/nemoguardrails/library/injection_detection/flows.co b/nemoguardrails/library/injection_detection/flows.co
index 22ca9095f..26da02578 100644
--- a/nemoguardrails/library/injection_detection/flows.co
+++ b/nemoguardrails/library/injection_detection/flows.co
@@ -1,7 +1,19 @@
-# OUTPUT RAILS
-
flow injection detection
"""
Reject, omit, or sanitize injection attempts from the bot.
+ This rail operates on the $bot_message.
"""
- $bot_message = await InjectionDetectionAction(text=$bot_message)
+ response = await InjectionDetectionAction(text=$bot_message)
+ join_separator = ", "
+ injection_detection_action = $config.rails.config.injection_detection.action
+
+ if response["is_injection"]
+ if $config.enable_rails_exceptions
+ send InjectionDetectionRailException(message="Output not allowed. The output was blocked by the 'injection detection' flow.")
+ else if injection_detection_action == "reject"
+ bot "I'm sorry, the desired output triggered rule(s) designed to mitigate exploitation of {{ response.detections | join(join_separator) }}."
+ abort
+ else if injection_detection_action == "omit" or injection_detection_action == "sanitize"
+ $bot_message = response["text"]
+ else
+ $bot_message = response["text"]
diff --git a/nemoguardrails/library/injection_detection/flows.v1.co b/nemoguardrails/library/injection_detection/flows.v1.co
index 5cbdcad6e..45b0a6e65 100644
--- a/nemoguardrails/library/injection_detection/flows.v1.co
+++ b/nemoguardrails/library/injection_detection/flows.v1.co
@@ -1,5 +1,19 @@
-define subflow injection detection
+
+define flow injection detection
"""
Reject, omit, or sanitize injection attempts from the bot.
"""
- $bot_message = execute injection_detection(text=$bot_message)
+ $response = execute injection_detection(text=$bot_message)
+ $join_separator = ", "
+ $injection_detection_action = $config.rails.config.injection_detection.action
+ if $response["is_injection"]
+ if $config.enable_rails_exceptions
+ create event InjectionDetectionRailException(message="Output not allowed. The output was blocked by the 'injection detection' flow.")
+ stop
+ else if $config.rails.config.injection_detection.action == "reject"
+ bot say "I'm sorry, the desired output triggered rule(s) designed to mitigate exploitation of {{ response.detections | join(join_separator) }}."
+ stop
+ else if $injection_detection_action == "omit" or $injection_detection_action == "sanitize"
+ $bot_message = $response["text"]
+ else
+ $bot_message = $response["text"]
diff --git a/tests/test_injection_detection.py b/tests/test_injection_detection.py
index fcc4519cf..fe51ab156 100644
--- a/tests/test_injection_detection.py
+++ b/tests/test_injection_detection.py
@@ -28,6 +28,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
import os
from unittest.mock import patch
@@ -44,6 +45,7 @@
_load_rules,
_omit_injection,
_reject_injection,
+ _sanitize_injection,
_validate_injection_config,
)
from tests.utils import TestChat
@@ -265,7 +267,7 @@ async def test_omit_injection_action():
create_mock_yara_match("-- comment", "sqli"),
]
- result = _omit_injection(text=text, matches=mock_matches)
+ is_injection, result = _omit_injection(text=text, matches=mock_matches)
# all sql injection should be removed
# NOTE: following rule does not get removed using sqli.yara
@@ -407,9 +409,12 @@ async def test_xss_injection_with_reject_action():
messages=[{"role": "user", "content": "do a fake query you funny agent"}]
)
- assert result["content"].startswith(
- "I'm sorry, the desired output triggered rule(s) designed to mitigate exploitation of"
- )
+ assert "content" in result, f"Expected 'content' key in result, got {result}"
+ actual_content = result["content"]
+ expected_start = "I'm sorry, the desired output triggered rule(s) designed to mitigate exploitation of"
+ assert actual_content.startswith(
+ expected_start
+ ), f"Expected content to start with '{expected_start}', but got '{actual_content}'"
@pytest.mark.asyncio
@@ -616,3 +621,236 @@ async def test_multiple_injection_types_reject_inline_rules():
assert "simple_xss" in result["content"]
assert "simple_template" in result["content"]
assert "simple_code" in result["content"]
+
+
+@pytest.mark.asyncio
+async def test_xss_injection_with_exception_action():
+ """Test that XSS detection with the reject action emits a rail exception when enable_rails_exceptions is True (uses real YARA rules from the library)."""
+
+ config = RailsConfig.from_content(
+ yaml_content="""
+ models: []
+ enable_rails_exceptions: True
+ rails:
+ config:
+ injection_detection:
+ injections:
+ - xss
+ action:
+ reject
+ output:
+ flows:
+ - injection detection
+
+ """
+ )
+
+ xss_injection = "Hello world"
+ chat = TestChat(config, llm_completions=[xss_injection])
+ rails = chat.app
+ result = await rails.generate_async(
+ messages=[{"role": "user", "content": "do a fake query you funny agent"}]
+ )
+
+ assert (
+ result.get("role") == "exception"
+ ), f"Expected role 'exception', got {result.get('role')}"
+ assert "content" in result, "Expected 'content' key in result"
+ content = result["content"]
+ assert (
+ content.get("type") == "InjectionDetectionRailException"
+ ), f"Expected type 'InjectionDetectionRailException', got {content.get('type')}"
+ expected_message = (
+ "Output not allowed. The output was blocked by the 'injection detection' flow."
+ )
+ assert (
+ content.get("message") == expected_message
+ ), f"Expected message '{expected_message}', got '{content.get('message')}'"
+
+
+@pytest.mark.asyncio
+async def test_omit_action_with_exceptions_enabled():
+ """Test that the omit action still emits a rail exception when enable_rails_exceptions is True."""
+
+ config = RailsConfig.from_content(
+ yaml_content="""
+ models: []
+ enable_rails_exceptions: True
+ rails:
+ config:
+ injection_detection:
+ injections:
+ - xss
+ action:
+ omit
+ output:
+ flows:
+ - injection detection
+
+ """
+ )
+
+ xss_injection = "Hello world"
+ chat = TestChat(config, llm_completions=[xss_injection])
+ rails = chat.app
+ result = await rails.generate_async(
+ messages=[{"role": "user", "content": "do a fake query you funny agent"}]
+ )
+
+ # check that an exception is raised
+ assert result.get("role") == "exception", "Expected role to be 'exception'"
+
+ # verify exception details
+ content = result["content"]
+ assert (
+ content.get("type") == "InjectionDetectionRailException"
+ ), f"Expected type 'InjectionDetectionRailException', got {content.get('type')}"
+
+ expected_message = (
+ "Output not allowed. The output was blocked by the 'injection detection' flow."
+ )
+ assert (
+ content.get("message") == expected_message
+ ), f"Expected message '{expected_message}', got '{content.get('message')}'"
+
+
+@pytest.mark.asyncio
+async def test_malformed_inline_yara_rule_fails_gracefully(caplog):
+ """Test that a malformed inline YARA rule leads to graceful failure (detection becomes no-op)."""
+
+ inline_rule_name = "malformed_rule"
+ # this rule is malformed: missing { after rule name
+ malformed_rule_content = "rule malformed_rule condition: true "
+
+ config = RailsConfig.from_content(
+ yaml_content=f"""
+ models: []
+ rails:
+ config:
+ injection_detection:
+ injections:
+ - {inline_rule_name}
+ action:
+ reject # can be anything
+ yara_rules:
+ {inline_rule_name}: |
+ {malformed_rule_content}
+ output:
+ flows:
+ - injection detection
+ """,
+ colang_content="",
+ )
+
+ some_text_that_would_be_injection = "This is a test string."
+
+ caplog.set_level(logging.ERROR, logger="actions.py")
+
+ chat = TestChat(config, llm_completions=[some_text_that_would_be_injection])
+ rails = chat.app
+
+ assert rails is not None
+
+ result = await rails.generate_async(
+ messages=[{"role": "user", "content": "trigger detection"}]
+ )
+
+ # check that no exception was raised
+ assert result.get("role") != "exception", f"Expected no exception, but got {result}"
+
+ # verify the error log was created with the expected content
+ assert any(
+ record.name == "actions.py" and record.levelno == logging.ERROR
+ # minor variations in the error message are expected
+ and "Failed to initialize injection detection" in record.message
+ and "YARA compilation failed" in record.message
+ and "syntax error" in record.message
+ for record in caplog.records
+ ), "Expected error log message about YARA compilation failure not found"
+
+
+@pytest.mark.asyncio
+async def test_omit_injection_attribute_error():
+ """Test error handling in _omit_injection for AttributeError."""
+
+ text = "test text"
+ mock_matches = [
+ create_mock_yara_match(
+ "invalid bytes", "test_rule"
+ ) # This will cause AttributeError
+ ]
+
+ is_injection, result = _omit_injection(text=text, matches=mock_matches)
+ assert not is_injection
+ assert result == text
+
+
+@pytest.mark.asyncio
+async def test_omit_injection_unicode_decode_error():
+ """Test error handling in _omit_injection for UnicodeDecodeError."""
+
+ text = "test text"
+
+ class MockStringMatchInstanceUnicode:
+ def __init__(self):
+ # invalid utf-8 bytes
+ self._text = b"\xff\xfe"
+
+ def plaintext(self):
+ return self._text
+
+ class MockStringMatchUnicode:
+ def __init__(self):
+ self.identifier = "test_string"
+ self.instances = [MockStringMatchInstanceUnicode()]
+
+ class MockMatchUnicode:
+ def __init__(self, rule):
+ self.rule = rule
+ self.strings = [MockStringMatchUnicode()]
+
+ mock_matches = [MockMatchUnicode("test_rule")]
+ is_injection, result = _omit_injection(text=text, matches=mock_matches)
+ assert not is_injection
+ assert result == text
+
+
+@pytest.mark.asyncio
+async def test_omit_injection_no_modifications():
+ """Test _omit_injection when no modifications are made to the text."""
+
+ text = "safe text"
+ mock_matches = [create_mock_yara_match("nonexistent pattern", "test_rule")]
+
+ is_injection, result = _omit_injection(text=text, matches=mock_matches)
+ assert not is_injection
+ assert result == text
+
+
+@pytest.mark.asyncio
+async def test_sanitize_injection_not_implemented():
+ """Test that _sanitize_injection raises NotImplementedError."""
+
+ text = "test text"
+ mock_matches = [create_mock_yara_match("test pattern", "test_rule")]
+
+ with pytest.raises(NotImplementedError) as exc_info:
+ _sanitize_injection(text=text, matches=mock_matches)
+ assert "Injection sanitization is not yet implemented" in str(exc_info.value)
+
+
+@pytest.mark.asyncio
+async def test_reject_injection_no_rules(caplog):
+ """Test _reject_injection when no rules are specified."""
+
+ text = "test text"
+ caplog.set_level(logging.WARNING)
+
+ is_injection, detections = _reject_injection(text=text, rules=None)
+ assert not is_injection
+ assert detections == []
+ assert any(
+ "reject_injection guardrail was invoked but no rules were specified"
+ in record.message
+ for record in caplog.records
+ )