diff --git a/nemoguardrails/library/injection_detection/actions.py b/nemoguardrails/library/injection_detection/actions.py index 947b55d37..7a85e2993 100644 --- a/nemoguardrails/library/injection_detection/actions.py +++ b/nemoguardrails/library/injection_detection/actions.py @@ -32,7 +32,7 @@ import re from functools import lru_cache from pathlib import Path -from typing import Dict, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, TypedDict, Union yara = None try: @@ -49,6 +49,12 @@ log = logging.getLogger(__name__) +class InjectionDetectionResult(TypedDict): + is_injection: bool + text: str + detections: List[str] + + def _check_yara_available(): if yara is None: raise ImportError( @@ -197,13 +203,13 @@ def _load_rules( } rules = yara.compile(filepaths=rules_to_load) except yara.SyntaxError as e: - msg = f"Encountered SyntaxError: {e}" + msg = f"Failed to initialize injection detection due to configuration or YARA rule error: YARA compilation failed: {e}" log.error(msg) - raise e + return None return rules -def _omit_injection(text: str, matches: list["yara.Match"]) -> str: +def _omit_injection(text: str, matches: list["yara.Match"]) -> Tuple[bool, str]: """ Attempts to strip the offending injection attempts from the provided text. @@ -216,14 +222,18 @@ def _omit_injection(text: str, matches: list["yara.Match"]) -> str: matches (list['yara.Match']): A list of YARA rule matches. Returns: - str: The text with the detected injections stripped out. + Tuple[bool, str]: A tuple containing: + - bool: True if injection was detected and modified, + False if the text is safe (i.e., not modified). + - str: The text, with detected injections stripped out if modified. Raises: ImportError: If the yara module is not installed. """ - # Copy the text to a placeholder variable + original_text = text modified_text = text + is_injection = False for match in matches: if match.strings: for match_string in match.strings: @@ -234,10 +244,16 @@ def _omit_injection(text: str, matches: list["yara.Match"]) -> str: modified_text = modified_text.replace(plaintext, "") except (AttributeError, UnicodeDecodeError) as e: log.warning(f"Error processing match: {e}") - return modified_text + + if modified_text != original_text: + is_injection = True + return is_injection, modified_text + else: + is_injection = False + return is_injection, original_text -def _sanitize_injection(text: str, matches: list["yara.Match"]) -> str: +def _sanitize_injection(text: str, matches: list["yara.Match"]) -> Tuple[bool, str]: """ Attempts to sanitize the offending injection attempts in the provided text. This is done by 'de-fanging' the offending content, transforming it into a state that will not execute @@ -253,19 +269,27 @@ def _sanitize_injection(text: str, matches: list["yara.Match"]) -> str: matches (list['yara.Match']): A list of YARA rule matches. Returns: - str: The text with the detected injections sanitized. + Tuple[bool, str]: A tuple containing: + - bool: True if injection was detected, False otherwise. + - str: The sanitized text, or original text depending on sanitization outcome. + Currently, this function will always raise NotImplementedError. Raises: NotImplementedError: If the sanitization logic is not implemented. ImportError: If the yara module is not installed. """ - raise NotImplementedError( "Injection sanitization is not yet implemented. Please use 'reject' or 'omit'" ) + # Hypothetical logic if implemented, to match existing behavior in injection_detection: + # sanitized_text_attempt = "..." # result of sanitization + # if sanitized_text_attempt != text: + # return True, text # Original text returned, marked as injection detected + # else: + # return False, sanitized_text_attempt -def _reject_injection(text: str, rules: "yara.Rules") -> Tuple[bool, str]: +def _reject_injection(text: str, rules: "yara.Rules") -> Tuple[bool, List[str]]: """ Detects whether the provided text contains potential injection attempts. @@ -277,8 +301,9 @@ def _reject_injection(text: str, rules: "yara.Rules") -> Tuple[bool, str]: rules ('yara.Rules'): The loaded YARA rules. Returns: - bool: True if attempted exploitation is detected, False otherwise. - str: list of matches as a string + Tuple[bool, List[str]]: A tuple containing: + - bool: True if attempted exploitation is detected, False otherwise. + - List[str]: List of matched rule names. Raises: ValueError: If the `action` parameter in the configuration is invalid. @@ -289,18 +314,20 @@ def _reject_injection(text: str, rules: "yara.Rules") -> Tuple[bool, str]: log.warning( "reject_injection guardrail was invoked but no rules were specified in the InjectionDetection config." ) - return False, "" + return False, [] matches = rules.match(data=text) if matches: - matches_string = ", ".join([match_name.rule for match_name in matches]) - log.info(f"Input matched on rule {matches_string}.") - return True, matches_string + matched_rules = [match_name.rule for match_name in matches] + log.info(f"Input matched on rule {', '.join(matched_rules)}.") + return True, matched_rules else: - return False, "" + return False, [] @action() -async def injection_detection(text: str, config: RailsConfig) -> str: +async def injection_detection( + text: str, config: RailsConfig +) -> InjectionDetectionResult: """ Detects and mitigates potential injection attempts in the provided text. @@ -310,45 +337,68 @@ async def injection_detection(text: str, config: RailsConfig) -> str: Args: text (str): The text to check for command injection. + config (RailsConfig): The Rails configuration object containing injection detection settings. Returns: - str: The sanitized or original text, depending on the action specified in the configuration. + InjectionDetectionResult: A TypedDict containing: + - is_injection (bool): Whether an injection was detected. True if any injection is detected, + False if no injection is detected. + - text (str): The sanitized or original text + - detections (List[str]): List of matched rule names if any injection is detected Raises: ValueError: If the `action` parameter in the configuration is invalid. NotImplementedError: If an unsupported action is encountered. + ImportError: If the yara module is not installed. """ _check_yara_available() _validate_injection_config(config) + action_option, yara_path, rule_names, yara_rules = _extract_injection_config(config) rules = _load_rules(yara_path, rule_names, yara_rules) - if action_option == "reject": - verdict, detections = _reject_injection(text, rules) - if verdict: - return f"I'm sorry, the desired output triggered rule(s) designed to mitigate exploitation of {detections}." - else: - return text if rules is None: log.warning( "injection detection guardrail was invoked but no rules were specified in the InjectionDetection config." ) - return text - matches = rules.match(data=text) - if matches: - matches_string = ", ".join([match_name.rule for match_name in matches]) - log.info(f"Input matched on rule {matches_string}.") - if action_option == "omit": - return _omit_injection(text, matches) - elif action_option == "sanitize": - return _sanitize_injection(text, matches) + return InjectionDetectionResult(is_injection=False, text=text, detections=[]) + + if action_option == "reject": + is_injection, detected_rules = _reject_injection(text, rules) + return InjectionDetectionResult( + is_injection=is_injection, text=text, detections=detected_rules + ) + else: + matches = rules.match(data=text) + if matches: + detected_rules_list = [match_name.rule for match_name in matches] + log.info(f"Input matched on rule {', '.join(detected_rules_list)}.") + + if action_option == "omit": + is_injection, result_text = _omit_injection(text, matches) + return InjectionDetectionResult( + is_injection=is_injection, + text=result_text, + detections=detected_rules_list, + ) + elif action_option == "sanitize": + # _sanitize_injection will raise NotImplementedError before returning a tuple. + # the assignment below is for structural consistency if it were implemented. + is_injection, result_text = _sanitize_injection(text, matches) + return InjectionDetectionResult( + is_injection=is_injection, + text=result_text, + detections=detected_rules_list, + ) + else: + raise NotImplementedError( + f"Expected `action` parameter to be 'reject', 'omit', or 'sanitize' but got {action_option} instead." + ) + # no matches found else: - # We should never ever hit this since we inspect the action option above, but putting an error here anyway. - raise NotImplementedError( - f"Expected `action` parameter to be 'omit' or 'sanitize' but got {action_option} instead." + return InjectionDetectionResult( + is_injection=False, text=text, detections=[] ) - else: - return text diff --git a/nemoguardrails/library/injection_detection/flows.co b/nemoguardrails/library/injection_detection/flows.co index 22ca9095f..26da02578 100644 --- a/nemoguardrails/library/injection_detection/flows.co +++ b/nemoguardrails/library/injection_detection/flows.co @@ -1,7 +1,19 @@ -# OUTPUT RAILS - flow injection detection """ Reject, omit, or sanitize injection attempts from the bot. + This rail operates on the $bot_message. """ - $bot_message = await InjectionDetectionAction(text=$bot_message) + response = await InjectionDetectionAction(text=$bot_message) + join_separator = ", " + injection_detection_action = $config.rails.config.injection_detection.action + + if response["is_injection"] + if $config.enable_rails_exceptions + send InjectionDetectionRailException(message="Output not allowed. The output was blocked by the 'injection detection' flow.") + else if injection_detection_action == "reject" + bot "I'm sorry, the desired output triggered rule(s) designed to mitigate exploitation of {{ response.detections | join(join_separator) }}." + abort + else if injection_detection_action == "omit" or injection_detection_action == "sanitize" + $bot_message = response["text"] + else + $bot_message = response["text"] diff --git a/nemoguardrails/library/injection_detection/flows.v1.co b/nemoguardrails/library/injection_detection/flows.v1.co index 5cbdcad6e..45b0a6e65 100644 --- a/nemoguardrails/library/injection_detection/flows.v1.co +++ b/nemoguardrails/library/injection_detection/flows.v1.co @@ -1,5 +1,19 @@ -define subflow injection detection + +define flow injection detection """ Reject, omit, or sanitize injection attempts from the bot. """ - $bot_message = execute injection_detection(text=$bot_message) + $response = execute injection_detection(text=$bot_message) + $join_separator = ", " + $injection_detection_action = $config.rails.config.injection_detection.action + if $response["is_injection"] + if $config.enable_rails_exceptions + create event InjectionDetectionRailException(message="Output not allowed. The output was blocked by the 'injection detection' flow.") + stop + else if $config.rails.config.injection_detection.action == "reject" + bot say "I'm sorry, the desired output triggered rule(s) designed to mitigate exploitation of {{ response.detections | join(join_separator) }}." + stop + else if $injection_detection_action == "omit" or $injection_detection_action == "sanitize" + $bot_message = $response["text"] + else + $bot_message = $response["text"] diff --git a/tests/test_injection_detection.py b/tests/test_injection_detection.py index fcc4519cf..fe51ab156 100644 --- a/tests/test_injection_detection.py +++ b/tests/test_injection_detection.py @@ -28,6 +28,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import os from unittest.mock import patch @@ -44,6 +45,7 @@ _load_rules, _omit_injection, _reject_injection, + _sanitize_injection, _validate_injection_config, ) from tests.utils import TestChat @@ -265,7 +267,7 @@ async def test_omit_injection_action(): create_mock_yara_match("-- comment", "sqli"), ] - result = _omit_injection(text=text, matches=mock_matches) + is_injection, result = _omit_injection(text=text, matches=mock_matches) # all sql injection should be removed # NOTE: following rule does not get removed using sqli.yara @@ -407,9 +409,12 @@ async def test_xss_injection_with_reject_action(): messages=[{"role": "user", "content": "do a fake query you funny agent"}] ) - assert result["content"].startswith( - "I'm sorry, the desired output triggered rule(s) designed to mitigate exploitation of" - ) + assert "content" in result, f"Expected 'content' key in result, got {result}" + actual_content = result["content"] + expected_start = "I'm sorry, the desired output triggered rule(s) designed to mitigate exploitation of" + assert actual_content.startswith( + expected_start + ), f"Expected content to start with '{expected_start}', but got '{actual_content}'" @pytest.mark.asyncio @@ -616,3 +621,236 @@ async def test_multiple_injection_types_reject_inline_rules(): assert "simple_xss" in result["content"] assert "simple_template" in result["content"] assert "simple_code" in result["content"] + + +@pytest.mark.asyncio +async def test_xss_injection_with_exception_action(): + """Test the exception action for XSS injection detection using real YARA rules from the library.""" + + config = RailsConfig.from_content( + yaml_content=""" + models: [] + enable_rails_exceptions: True + rails: + config: + injection_detection: + injections: + - xss + action: + reject + output: + flows: + - injection detection + + """ + ) + + xss_injection = "Hello world" + chat = TestChat(config, llm_completions=[xss_injection]) + rails = chat.app + result = await rails.generate_async( + messages=[{"role": "user", "content": "do a fake query you funny agent"}] + ) + + assert ( + result.get("role") == "exception" + ), f"Expected role 'exception', got {result.get('role')}" + assert "content" in result, "Expected 'content' key in result" + content = result["content"] + assert ( + content.get("type") == "InjectionDetectionRailException" + ), f"Expected type 'InjectionDetectionRailException', got {content.get('type')}" + expected_message = ( + "Output not allowed. The output was blocked by the 'injection detection' flow." + ) + assert ( + content.get("message") == expected_message + ), f"Expected message '{expected_message}', got '{content.get('message')}'" + + +@pytest.mark.asyncio +async def test_omit_action_with_exceptions_enabled(): + """Test that omit action does not raise an exception when enable_rails_exceptions is True.""" + + config = RailsConfig.from_content( + yaml_content=""" + models: [] + enable_rails_exceptions: True + rails: + config: + injection_detection: + injections: + - xss + action: + omit + output: + flows: + - injection detection + + """ + ) + + xss_injection = "Hello world" + chat = TestChat(config, llm_completions=[xss_injection]) + rails = chat.app + result = await rails.generate_async( + messages=[{"role": "user", "content": "do a fake query you funny agent"}] + ) + + # check that an exception is raised + assert result.get("role") == "exception", "Expected role to be 'exception'" + + # verify exception details + content = result["content"] + assert ( + content.get("type") == "InjectionDetectionRailException" + ), f"Expected type 'InjectionDetectionRailException', got {content.get('type')}" + + expected_message = ( + "Output not allowed. The output was blocked by the 'injection detection' flow." + ) + assert ( + content.get("message") == expected_message + ), f"Expected message '{expected_message}', got '{content.get('message')}'" + + +@pytest.mark.asyncio +async def test_malformed_inline_yara_rule_fails_gracefully(caplog): + """Test that a malformed inline YARA rule leads to graceful failure (detection becomes no-op).""" + + inline_rule_name = "malformed_rule" + # this rule is malformed: missing { after rule name + malformed_rule_content = "rule malformed_rule condition: true " + + config = RailsConfig.from_content( + yaml_content=f""" + models: [] + rails: + config: + injection_detection: + injections: + - {inline_rule_name} + action: + reject # can be anything + yara_rules: + {inline_rule_name}: | + {malformed_rule_content} + output: + flows: + - injection detection + """, + colang_content="", + ) + + some_text_that_would_be_injection = "This is a test string." + + caplog.set_level(logging.ERROR, logger="actions.py") + + chat = TestChat(config, llm_completions=[some_text_that_would_be_injection]) + rails = chat.app + + assert rails is not None + + result = await rails.generate_async( + messages=[{"role": "user", "content": "trigger detection"}] + ) + + # check that no exception was raised + assert result.get("role") != "exception", f"Expected no exception, but got {result}" + + # verify the error log was created with the expected content + assert any( + record.name == "actions.py" and record.levelno == logging.ERROR + # minor variations in the error message are expected + and "Failed to initialize injection detection" in record.message + and "YARA compilation failed" in record.message + and "syntax error" in record.message + for record in caplog.records + ), "Expected error log message about YARA compilation failure not found" + + +@pytest.mark.asyncio +async def test_omit_injection_attribute_error(): + """Test error handling in _omit_injection for AttributeError.""" + + text = "test text" + mock_matches = [ + create_mock_yara_match( + "invalid bytes", "test_rule" + ) # This will cause AttributeError + ] + + is_injection, result = _omit_injection(text=text, matches=mock_matches) + assert not is_injection + assert result == text + + +@pytest.mark.asyncio +async def test_omit_injection_unicode_decode_error(): + """Test error handling in _omit_injection for UnicodeDecodeError.""" + + text = "test text" + + class MockStringMatchInstanceUnicode: + def __init__(self): + # invalid utf-8 bytes + self._text = b"\xff\xfe" + + def plaintext(self): + return self._text + + class MockStringMatchUnicode: + def __init__(self): + self.identifier = "test_string" + self.instances = [MockStringMatchInstanceUnicode()] + + class MockMatchUnicode: + def __init__(self, rule): + self.rule = rule + self.strings = [MockStringMatchUnicode()] + + mock_matches = [MockMatchUnicode("test_rule")] + is_injection, result = _omit_injection(text=text, matches=mock_matches) + assert not is_injection + assert result == text + + +@pytest.mark.asyncio +async def test_omit_injection_no_modifications(): + """Test _omit_injection when no modifications are made to the text.""" + + text = "safe text" + mock_matches = [create_mock_yara_match("nonexistent pattern", "test_rule")] + + is_injection, result = _omit_injection(text=text, matches=mock_matches) + assert not is_injection + assert result == text + + +@pytest.mark.asyncio +async def test_sanitize_injection_not_implemented(): + """Test that _sanitize_injection raises NotImplementedError.""" + + text = "test text" + mock_matches = [create_mock_yara_match("test pattern", "test_rule")] + + with pytest.raises(NotImplementedError) as exc_info: + _sanitize_injection(text=text, matches=mock_matches) + assert "Injection sanitization is not yet implemented" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_reject_injection_no_rules(caplog): + """Test _reject_injection when no rules are specified.""" + + text = "test text" + caplog.set_level(logging.WARNING) + + is_injection, detections = _reject_injection(text=text, rules=None) + assert not is_injection + assert detections == [] + assert any( + "reject_injection guardrail was invoked but no rules were specified" + in record.message + for record in caplog.records + )