connection/aws_ssm - create TerminalManager class and move related methods by mandar242 · Pull Request #2270 · ansible-collections/community.aws · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

connection/aws_ssm - create TerminalManager class and move related methods #2270

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
---
minor_changes:
- aws_ssm - Refactor connection/aws_ssm to add a new TerminalManager class and move the terminal-related methods into it (https://github.com/ansible-collections/community.aws/pull/2270).
93 changes: 4 additions & 89 deletions plugins/connection/aws_ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,12 +356,12 @@
from ansible.module_utils.basic import missing_required_lib
from ansible.module_utils.common.process import get_bin_path
from ansible.plugins.connection import ConnectionBase
from ansible.plugins.shell.powershell import _common_args
from ansible.utils.display import Display

from ansible_collections.amazon.aws.plugins.module_utils.botocore import HAS_BOTO3

from ansible_collections.community.aws.plugins.plugin_utils.s3clientmanager import S3ClientManager
from ansible_collections.community.aws.plugins.plugin_utils.terminalmanager import TerminalManager

display = Display()

Expand Down Expand Up @@ -484,6 +484,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
self._instance_id = None
self._polling_obj = None
self._has_timeout = False
self.terminal_manager = TerminalManager(self)

if getattr(self._shell, "SHELL_FAMILY", "") == "powershell":
self.delegate = None
Expand Down Expand Up @@ -645,7 +646,7 @@ def start_session(self):
self._stdout = os.fdopen(stdout_r, "rb", 0)

# For non-windows Hosts: Ensure the session has started, and disable command echo and prompt.
self._prepare_terminal()
self.terminal_manager.prepare_terminal()

self.verbosity_display(4, f"SSM CONNECTION ID: {self._session_id}") # pylint: disable=unreachable

Expand Down Expand Up @@ -743,7 +744,7 @@ def exec_command(self, cmd: str, in_data: bool = None, sudoable: bool = True) ->
mark_end = self.generate_mark()

# Wrap command in markers accordingly for the shell used
cmd = self._wrap_command(cmd, mark_start, mark_end)
cmd = self.terminal_manager.wrap_command(cmd, mark_start, mark_end)

self._flush_stderr(self._session)

Expand All @@ -752,92 +753,6 @@ def exec_command(self, cmd: str, in_data: bool = None, sudoable: bool = True) ->

return self.exec_communicate(cmd, mark_start, mark_begin, mark_end)

def _ensure_ssm_session_has_started(self) -> None:
    """Block until the SSM start-of-session banner has been read from stdout.

    Whenever ``self.poll`` yields a truthy result, up to 1024 bytes are read
    from ``self._stdout`` and appended to the collected text. Reading stops as
    soon as that text contains 'Starting session with SessionId'. Exceptions
    raised by ``self.poll`` propagate unchanged.
    """
    collected = ""
    for ready in self.poll("START SSM SESSION", "start_session"):
        if not ready:
            continue
        collected += to_text(self._stdout.read(1024))
        self.verbosity_display(4, f"START SSM SESSION stdout line: \n{to_bytes(collected)}")
        if "Starting session with SessionId" in collected:
            self.verbosity_display(4, "START SSM SESSION startup output received")
            break

def _disable_prompt_command(self) -> None:
    """Blank the remote shell prompt and switch off bracketed paste.

    The command also prints a random marker of ``self.MARK_LENGTH`` ASCII
    letters. Stdout is then polled until that marker is seen framed by
    ``\\r\\r\\n`` sequences, or until ``self.poll`` stops yielding.
    """
    marker = "".join(random.choice(string.ascii_letters) for _ in range(self.MARK_LENGTH))
    payload = to_bytes(
        f"PS1='' ; bind 'set enable-bracketed-paste off'; printf '\\n%s\\n' '{marker}'\n",
        errors="surrogate_or_strict",
    )
    marker_line = re.compile(r"\r\r\n" + re.escape(marker) + r"\r\r\n", re.MULTILINE)

    # Send the command, then watch stdout for the marker line.
    self.verbosity_display(4, f"DISABLE PROMPT Disabling Prompt: \n{payload}")
    self._session.stdin.write(payload)

    received = ""
    for ready in self.poll("DISABLE PROMPT", payload):
        if not ready:
            continue
        received += to_text(self._stdout.read(1024))
        self.verbosity_display(4, f"DISABLE PROMPT stdout line: \n{to_bytes(received)}")
        if marker_line.search(received):
            break

def _disable_echo_command(self) -> None:
    """Turn off terminal echo on the host by sending ``stty -echo``.

    After the command is written to the session's stdin, stdout is polled
    (1024-byte reads whenever ``self.poll`` yields a truthy result) until the
    text ``stty -echo`` has been read back, or until ``self.poll`` stops
    yielding. Exceptions raised by ``self.poll`` propagate unchanged.
    """
    disable_echo_cmd = to_bytes("stty -echo\n", errors="surrogate_or_strict")

    # Send command. The label describes this step (echo), not the prompt step.
    self.verbosity_display(4, f"DISABLE ECHO Disabling Echo: \n{disable_echo_cmd}")
    self._session.stdin.write(disable_echo_cmd)

    stdout = ""
    for poll_result in self.poll("DISABLE ECHO", disable_echo_cmd):
        if not poll_result:
            continue
        stdout += to_text(self._stdout.read(1024))
        self.verbosity_display(4, f"DISABLE ECHO stdout line: \n{to_bytes(stdout)}")
        # Presumably the command text is echoed back once before echo takes
        # effect; seeing it is treated as confirmation. TODO confirm.
        if "stty -echo" in stdout:
            break

def _prepare_terminal(self) -> None:
    """Run the one-time terminal setup for a newly started session.

    Windows hosts are left untouched. On any other host this waits for the SSM
    session banner, disables echo, clears the prompt, and then logs that the
    terminal is configured.
    """
    if not self.is_windows:
        self._ensure_ssm_session_has_started()
        self._disable_echo_command()
        self._disable_prompt_command()
        self.verbosity_display(4, "PRE Terminal configured")

def _wrap_command(self, cmd: str, mark_start: str, mark_end: str) -> str:
    """Surround ``cmd`` with start/end markers so its output and status can be parsed.

    On Windows the command is PowerShell-encoded via ``self._shell`` unless it
    already starts with the encoded-command prefix, and both markers are echoed
    after it. On other hosts ``printf`` emits the start marker, the command runs
    with ``echo`` piped into it, and the exit status (``$?``) and end marker are
    printed last.
    """
    if not self.is_windows:
        wrapped = (
            f"printf '%s\\n' '{mark_start}';\n"
            f"echo | {cmd};\n"
            f"printf '\\n%s\\n%s\\n' \"$?\" '{mark_end}';\n"
        )  # fmt: skip
    else:
        encoded_prefix = " ".join(_common_args) + " -EncodedCommand"
        script = cmd if cmd.startswith(encoded_prefix) else self._shell._encode_script(cmd, preserve_rc=True)
        wrapped = f"{script}; echo {mark_start}\necho {mark_end}\n"

    self.verbosity_display(4, f"_wrap_command: \n'{to_text(wrapped)}'")
    return wrapped

def _post_process(self, stdout: str, mark_begin: str) -> Tuple[str, str]:
"""extract command status and strip unwanted lines"""

Expand Down
103 changes: 103 additions & 0 deletions plugins/plugin_utils/terminalmanager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# -*- coding: utf-8 -*-

# Copyright: Contributors to the Ansible project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)

import random
import re
import string

from ansible.module_utils._text import to_bytes
from ansible.module_utils._text import to_text
from ansible.plugins.shell.powershell import _common_args


class TerminalManager:
    """Terminal setup and command framing for an ``aws_ssm`` connection.

    The manager keeps a reference to the owning connection plugin. It works
    through that connection's ``is_windows``, ``MARK_LENGTH``, ``_shell``,
    ``_session``, ``_stdout``, ``poll`` and ``verbosity_display`` members.
    """

    def __init__(self, connection):
        """
        :param connection: the ``aws_ssm`` connection plugin instance whose SSM
            session this manager configures and wraps commands for.
        """
        self.connection = connection

    def prepare_terminal(self) -> None:
        """Perform one-time terminal settings on a freshly started session.

        Nothing is done for Windows hosts. Otherwise this waits for the SSM
        session banner, disables terminal echo, clears the shell prompt, and
        logs that the terminal is configured.
        """
        if self.connection.is_windows:
            return

        self.ensure_ssm_session_has_started()
        self.disable_echo_command()
        self.disable_prompt_command()

        self.connection.verbosity_display(4, "PRE Terminal configured")

    def wrap_command(self, cmd: str, mark_start: str, mark_end: str) -> str:
        """Wrap ``cmd`` in marker lines so its stdout and exit status can be extracted.

        :param cmd: the command to run on the remote host.
        :param mark_start: marker emitted before the command's output.
        :param mark_end: marker emitted after the command's output (on non-Windows
            hosts it follows the exit status line).
        :returns: the wrapped, newline-terminated command text.
        """
        if self.connection.is_windows:
            # Encode the script unless the caller already supplied an encoded PowerShell command.
            if not cmd.startswith(" ".join(_common_args) + " -EncodedCommand"):
                cmd = self.connection._shell._encode_script(cmd, preserve_rc=True)
            cmd = f"{cmd}; echo {mark_start}\necho {mark_end}\n"
        else:
            # `echo |` feeds the command a single empty line on stdin. Presumably
            # this keeps it from waiting on terminal input (TODO confirm). "$?" is
            # the pipeline's exit status, i.e. that of `cmd`.
            cmd = (
                f"printf '%s\\n' '{mark_start}';\n"
                f"echo | {cmd};\n"
                f"printf '\\n%s\\n%s\\n' \"$?\" '{mark_end}';\n"
            )  # fmt: skip

        self.connection.verbosity_display(4, f"wrap_command: \n'{to_text(cmd)}'")
        return cmd

    def disable_echo_command(self) -> None:
        """Turn off terminal echo on the host by sending ``stty -echo``.

        Returns once the text ``stty -echo`` has been read back from stdout, or
        when polling ends. Exceptions raised by ``connection.poll`` propagate.
        """
        disable_echo_cmd = to_bytes("stty -echo\n", errors="surrogate_or_strict")

        self.connection.verbosity_display(4, f"DISABLE ECHO Disabling Echo: \n{disable_echo_cmd}")
        self.connection._session.stdin.write(disable_echo_cmd)

        # Presumably the command text is echoed back once before echo takes
        # effect; seeing it is treated as confirmation. TODO confirm.
        self._read_until("DISABLE ECHO", disable_echo_cmd, lambda out: "stty -echo" in out)

    def disable_prompt_command(self) -> None:
        """Clear the shell prompt (``PS1``) and disable bracketed paste.

        After applying the settings, the command prints a random marker of
        ``connection.MARK_LENGTH`` ASCII letters. This method returns once that
        marker is seen framed by ``\\r\\r\\n`` sequences in stdout, or when
        polling ends. Exceptions raised by ``connection.poll`` propagate.
        """
        end_mark = "".join(random.choice(string.ascii_letters) for _ in range(self.connection.MARK_LENGTH))
        disable_prompt_cmd = to_bytes(
            f"PS1='' ; bind 'set enable-bracketed-paste off'; printf '\\n%s\\n' '{end_mark}'\n",
            errors="surrogate_or_strict",
        )
        disable_prompt_reply = re.compile(r"\r\r\n" + re.escape(end_mark) + r"\r\r\n", re.MULTILINE)

        self.connection.verbosity_display(4, f"DISABLE PROMPT Disabling Prompt: \n{disable_prompt_cmd}")
        self.connection._session.stdin.write(disable_prompt_cmd)

        self._read_until("DISABLE PROMPT", disable_prompt_cmd, disable_prompt_reply.search)

    def ensure_ssm_session_has_started(self) -> None:
        """Ensure the SSM session has started on the host.

        Polls stdout until it contains 'Starting session with SessionId', or
        until polling ends. Exceptions raised by ``connection.poll`` propagate.
        """
        started = self._read_until(
            "START SSM SESSION",
            "start_session",
            lambda out: "Starting session with SessionId" in out,
        )
        if started:
            self.connection.verbosity_display(4, "START SSM SESSION startup output received")

    def _read_until(self, label: str, cmd, done) -> bool:
        """Accumulate session stdout until ``done(stdout)`` is truthy.

        ``connection.poll(label, cmd)`` drives the loop. Each truthy poll result
        reads up to 1024 bytes from ``connection._stdout``, appends the text, and
        logs the accumulated output under ``label``.

        :param label: prefix for the poll call and debug messages.
        :param cmd: value passed through to ``connection.poll``.
        :param done: predicate applied to the accumulated stdout text.
        :returns: ``True`` if ``done`` matched, ``False`` if polling ended first.
        """
        stdout = ""
        for poll_result in self.connection.poll(label, cmd):
            if not poll_result:
                continue
            stdout += to_text(self.connection._stdout.read(1024))
            self.connection.verbosity_display(4, f"{label} stdout line: \n{to_bytes(stdout)}")
            if done(stdout):
                return True
        return False
2 changes: 2 additions & 0 deletions tests/unit/plugins/connection/aws_ssm/test_exec_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ def test_connection_aws_ssm_exec_command(m_chunks, connection_aws_ssm, is_window
cmd = MagicMock()
in_data = MagicMock()
sudoable = MagicMock()
connection_aws_ssm.terminal_manager = MagicMock()

assert result == connection_aws_ssm.exec_command(cmd, in_data, sudoable)
# m_chunks.assert_called_once_with(chunk, 1024)
connection_aws_ssm._flush_stderr.assert_called_once_with(connection_aws_ssm._session)
33 changes: 22 additions & 11 deletions tests/unit/plugins/connection/aws_ssm/test_prepare_terminal.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

from ansible_collections.amazon.aws.plugins.module_utils.botocore import HAS_BOTO3

from ansible_collections.community.aws.plugins.connection.aws_ssm import TerminalManager

if not HAS_BOTO3:
pytestmark = pytest.mark.skip("test_poll.py requires the python modules 'boto3' and 'botocore'")

Expand Down Expand Up @@ -47,15 +49,18 @@ def test_ensure_ssm_session_has_started(m_to_text, m_to_bytes, connection_aws_ss

connection_aws_ssm._stdout.read.side_effect = stdout_lines

if not hasattr(connection_aws_ssm, "terminal_manager"):
connection_aws_ssm.terminal_manager = TerminalManager(connection_aws_ssm)

poll_mock.results = [True for i in range(len(stdout_lines))]
connection_aws_ssm.poll = MagicMock()
connection_aws_ssm.poll.side_effect = poll_mock

if timeout_failure:
with pytest.raises(TimeoutError):
connection_aws_ssm._ensure_ssm_session_has_started()
connection_aws_ssm.terminal_manager.ensure_ssm_session_has_started()
else:
connection_aws_ssm._ensure_ssm_session_has_started()
connection_aws_ssm.terminal_manager.ensure_ssm_session_has_started()


@pytest.mark.parametrize(
Expand All @@ -67,8 +72,8 @@ def test_ensure_ssm_session_has_started(m_to_text, m_to_bytes, connection_aws_ss
(["stty ", "-ech"], True),
],
)
@patch("ansible_collections.community.aws.plugins.connection.aws_ssm.to_bytes")
@patch("ansible_collections.community.aws.plugins.connection.aws_ssm.to_text")
@patch("ansible_collections.community.aws.plugins.plugin_utils.terminalmanager.to_bytes")
@patch("ansible_collections.community.aws.plugins.plugin_utils.terminalmanager.to_text")
def test_disable_echo_command(m_to_text, m_to_bytes, connection_aws_ssm, stdout_lines, timeout_failure):
m_to_text.side_effect = str
m_to_bytes.side_effect = lambda x, **kw: str(x)
Expand All @@ -80,19 +85,22 @@ def test_disable_echo_command(m_to_text, m_to_bytes, connection_aws_ssm, stdout_
connection_aws_ssm.poll = MagicMock()
connection_aws_ssm.poll.side_effect = poll_mock

if not hasattr(connection_aws_ssm, "terminal_manager"):
connection_aws_ssm.terminal_manager = TerminalManager(connection_aws_ssm)

if timeout_failure:
with pytest.raises(TimeoutError):
connection_aws_ssm._disable_echo_command()
connection_aws_ssm.terminal_manager.disable_echo_command()
else:
connection_aws_ssm._disable_echo_command()
connection_aws_ssm.terminal_manager.disable_echo_command()

connection_aws_ssm._session.stdin.write.assert_called_once_with("stty -echo\n")


@pytest.mark.parametrize("timeout_failure", [True, False])
@patch("ansible_collections.community.aws.plugins.connection.aws_ssm.random")
@patch("ansible_collections.community.aws.plugins.connection.aws_ssm.to_bytes")
@patch("ansible_collections.community.aws.plugins.connection.aws_ssm.to_text")
@patch("ansible_collections.community.aws.plugins.plugin_utils.terminalmanager.random")
@patch("ansible_collections.community.aws.plugins.plugin_utils.terminalmanager.to_bytes")
@patch("ansible_collections.community.aws.plugins.plugin_utils.terminalmanager.to_text")
def test_disable_prompt_command(m_to_text, m_to_bytes, m_random, connection_aws_ssm, timeout_failure):
m_to_text.side_effect = str
m_to_bytes.side_effect = lambda x, **kw: str(x)
Expand All @@ -101,6 +109,9 @@ def test_disable_prompt_command(m_to_text, m_to_bytes, m_random, connection_aws_
connection_aws_ssm.poll = MagicMock()
connection_aws_ssm.poll.side_effect = poll_mock

if not hasattr(connection_aws_ssm, "terminal_manager"):
connection_aws_ssm.terminal_manager = TerminalManager(connection_aws_ssm)

m_random.choice = MagicMock()
m_random.choice.side_effect = lambda x: "a"

Expand All @@ -115,8 +126,8 @@ def test_disable_prompt_command(m_to_text, m_to_bytes, m_random, connection_aws_

if timeout_failure:
with pytest.raises(TimeoutError):
connection_aws_ssm._disable_prompt_command()
connection_aws_ssm.terminal_manager.disable_prompt_command()
else:
connection_aws_ssm._disable_prompt_command()
connection_aws_ssm.terminal_manager.disable_prompt_command()

connection_aws_ssm._session.stdin.write.assert_called_once_with(prompt_cmd)
Loading
0