From 59164fb71031e85c92e852881868f0b4e85c962b Mon Sep 17 00:00:00 2001 From: Nikhil Benesch Date: Tue, 21 Dec 2021 17:01:26 -0500 Subject: [PATCH] Don't set `shell=True` with untrusted input Previously `mssh` would blindly execute an SSH command, resulting in shell pipelines being executed on the *host* rather on the SSH target. Consider the following command: $ mssh i-04bb8a432b18b2250 'whoami; whoami' ubuntu benesch The second invocation of "whoami" runs on the host and therefore prints my local username, rather than the username on the EC2 instance. This is at odds with the normal SSH program, which would print "ubuntu" for both, as any shell metacharacters are left to be interpreted by the remote shell. This issue was previously reported as #24, with a proposed fix in #25 that simply shell quotes the command. That solution seems suboptimal to me, as it is generally a bad idea to pass user input to a shell. This commit solves the issue another way, by keeping track of individual arguments as we go. Rather than building up a command string like "ssh ubuntu@10.0.0.1 USER-FLAGS USER-COMMAND" and then passing that to the local shell for interpretation, we instead build up a command array like: ["ssh", "ubuntu@10.0.0.1", "USER-FLAG-1", "USER-FLAG-2", "USER-COMMAND"] This command can be executed without invoking the shell, and so we can be sure it will not execute any code on the host. Fix #24. --- .../EC2InstanceConnectCLI.py | 12 ++--- .../EC2InstanceConnectCommand.py | 11 ++--- ec2instanceconnectcli/input_parser.py | 6 +-- tests/test_EC2ConnectCLI.py | 49 ++++++++++--------- tests/test_input_parser.py | 12 ++--- 5 files changed, 44 insertions(+), 46 deletions(-) diff --git a/ec2instanceconnectcli/EC2InstanceConnectCLI.py b/ec2instanceconnectcli/EC2InstanceConnectCLI.py index 7d7d9f3..b669326 100644 --- a/ec2instanceconnectcli/EC2InstanceConnectCLI.py +++ b/ec2instanceconnectcli/EC2InstanceConnectCLI.py @@ -81,18 +81,18 @@ def handle_keys(self): key_publisher.push_public_key(session, bundle['instance_id'], bundle['username'], self.pub_key, bundle['zone']) self.logger.debug('Successfully pushed the public key to {0}'.format(bundle['instance_id'])) - def run_command(self, command=None): + def run_command(self, args=None): """ - Runs the given command in a sub-shell - :param command: Command to invoke - :type command: basestring + Runs the given command + :param args: Arguments to invoke + :type args: list of strings :return: Return code for remote command :rtype: int """ - if not command: + if not args: raise ValueError('Must provide a command') - invocation_proc = Popen(command, shell=True) + invocation_proc = Popen(args) while invocation_proc.poll() is None: #sub-process not terminated time.sleep(0.1) return invocation_proc.returncode diff --git a/ec2instanceconnectcli/EC2InstanceConnectCommand.py b/ec2instanceconnectcli/EC2InstanceConnectCommand.py index 32de406..b0a12c2 100644 --- a/ec2instanceconnectcli/EC2InstanceConnectCommand.py +++ b/ec2instanceconnectcli/EC2InstanceConnectCommand.py @@ -43,21 +43,20 @@ def get_command(self): Generates and returns the generated command """ # Start with protocol & identity file - command = '{0} -o "IdentitiesOnly=yes" -i {1}'.format(self.program, self.key_file) + command = [self.program, '-o', 'IdentitiesOnly=yes', '-i', self.key_file] # Next add command flags if present - if len(self.flags) > 0: - command = "{0} {1}".format(command, self.flags) + command.extend(self.flags) # Target - command = "{0} {1}".format(command, self._get_target(self.instance_bundles[0])) + command.append(self._get_target(self.instance_bundles[0])) #program specific command if len(self.program_command) > 0: - command = "{0} {1}".format(command, self.program_command) + command.append(self.program_command) if len(self.instance_bundles) > 1: - command = "{0} {1}".format(command, self._get_target(self.instance_bundles[1])) + command.append(self._get_target(self.instance_bundles[1])) self.logger.debug('Generated command: {0}'.format(command)) diff --git a/ec2instanceconnectcli/input_parser.py b/ec2instanceconnectcli/input_parser.py index 705d214..e53df0c 100644 --- a/ec2instanceconnectcli/input_parser.py +++ b/ec2instanceconnectcli/input_parser.py @@ -117,7 +117,7 @@ def _parse_command_flags(raw_command, instance_bundles, is_ssh=False): :return: tuple of flags and final comamnd or file list :rtype: tuple """ - flags = '' + flags = [] is_user = False is_flagged = False command_index = 0 @@ -133,7 +133,7 @@ def _parse_command_flags(raw_command, instance_bundles, is_ssh=False): used += 1 # This is either a flag or a flag value - flags = '{0} {1}'.format(flags, raw_command[command_index]) + flags.append(raw_command[command_index]) if raw_command[command_index][0] == '-': # Flag @@ -152,8 +152,6 @@ def _parse_command_flags(raw_command, instance_bundles, is_ssh=False): command_index += 1 - flags = flags.strip() - """ Target host and command or file list """ diff --git a/tests/test_EC2ConnectCLI.py b/tests/test_EC2ConnectCLI.py index 584a4e9..f59117e 100644 --- a/tests/test_EC2ConnectCLI.py +++ b/tests/test_EC2ConnectCLI.py @@ -31,7 +31,7 @@ def test_mssh_no_target(self, mock_push_key, mock_run): mock_file = 'identity' - flag = '-f flag' + flags = ['-f', 'flag'] command = 'command arg' logger = EC2InstanceConnectLogger() instance_bundles = [{'username': self.default_user, 'instance_id': self.instance_id, @@ -41,12 +41,12 @@ def test_mssh_no_target(self, mock_instance_data.return_value = self.instance_info mock_push_key.return_value = None - cli_command = EC2InstanceConnectCommand("ssh", instance_bundles, mock_file, flag, command, logger.get_logger()) + cli_command = EC2InstanceConnectCommand("ssh", instance_bundles, mock_file, flags, command, logger.get_logger()) cli = EC2InstanceConnectCLI(instance_bundles, "", cli_command, logger.get_logger()) cli.invoke_command() - - expected_command = 'ssh -o "IdentitiesOnly=yes" -i {0} {1} {2}@{3} {4}'.format(mock_file, flag, self.default_user, - self.public_ip, command) + + expected_command = ['ssh', '-o', 'IdentitiesOnly=yes', '-i', mock_file, *flags, + '{}@{}'.format(self.default_user, self.public_ip), command] # Check that we successfully get to the run self.assertTrue(mock_instance_data.called) @@ -62,7 +62,7 @@ def test_mssh_no_target_no_public_ip(self, mock_push_key, mock_run): mock_file = "identity" - flag = '-f flag' + flags = ['-f', 'flag'] command = 'command arg' logger = EC2InstanceConnectLogger() instance_bundles = [{'username': self.default_user, 'instance_id': self.instance_id, @@ -72,12 +72,12 @@ def test_mssh_no_target_no_public_ip(self, mock_instance_data.return_value = self.private_instance_info mock_push_key.return_value = None - cli_command = EC2InstanceConnectCommand("ssh", instance_bundles, mock_file, flag, command, logger.get_logger()) + cli_command = EC2InstanceConnectCommand("ssh", instance_bundles, mock_file, flags, command, logger.get_logger()) cli = EC2InstanceConnectCLI(instance_bundles, "", cli_command, logger.get_logger()) cli.invoke_command() - expected_command = 'ssh -o "IdentitiesOnly=yes" -i {0} {1} {2}@{3} {4}'.format(mock_file, flag, self.default_user, - self.private_ip, command) + expected_command = ['ssh', '-o', 'IdentitiesOnly=yes', '-i', mock_file, *flags, + '{}@{}'.format(self.default_user, self.private_ip), command] # Check that we successfully get to the run self.assertTrue(mock_instance_data.called) @@ -92,7 +92,7 @@ def test_mssh_with_target(self, mock_push_key, mock_run): mock_file = 'identity' - flag = '-f flag' + flags = ['-f', 'flag'] command = 'command arg' host = '0.0.0.0' logger = EC2InstanceConnectLogger() @@ -103,12 +103,12 @@ def test_mssh_with_target(self, mock_instance_data.return_value = self.instance_info mock_push_key.return_value = None - cli_command = EC2InstanceConnectCommand("ssh", instance_bundles, mock_file, flag, command, logger.get_logger()) + cli_command = EC2InstanceConnectCommand("ssh", instance_bundles, mock_file, flags, command, logger.get_logger()) cli = EC2InstanceConnectCLI(instance_bundles, "", cli_command, logger.get_logger()) cli.invoke_command() - expected_command = 'ssh -o "IdentitiesOnly=yes" -i {0} {1} {2}@{3} {4}'.format(mock_file, flag, self.default_user, - host, command) + expected_command = ['ssh', '-o', 'IdentitiesOnly=yes', '-i', mock_file, *flags, + '{}@{}'.format(self.default_user, host), command] # Check that we successfully get to the run # Since both target and availability_zone are provided, mock_instance_data should not be called self.assertFalse(mock_instance_data.called) @@ -123,7 +123,7 @@ def test_msftp(self, mock_push_key, mock_run): mock_file = 'identity' - flag = '-f flag' + flags = ['-f', 'flag'] command = 'file2 file3' logger = EC2InstanceConnectLogger() instance_bundles = [{'username': self.default_user, 'instance_id': self.instance_id, @@ -133,10 +133,11 @@ def test_msftp(self, mock_instance_data.return_value = self.instance_info mock_push_key.return_value = None - expected_command = 'sftp -o "IdentitiesOnly=yes" -i {0} {1} {2}@{3}:{4} {5}'.format(mock_file, flag, self.default_user, - self.public_ip, 'file1', command) + expected_command = ['sftp', '-o', 'IdentitiesOnly=yes', '-i', mock_file, *flags, + '{}@{}:{}'.format(self.default_user, self.public_ip, 'file1'), + command] - cli_command = EC2InstanceConnectCommand("sftp", instance_bundles, mock_file, flag, command, logger.get_logger()) + cli_command = EC2InstanceConnectCommand("sftp", instance_bundles, mock_file, flags, command, logger.get_logger()) cli = EC2InstanceConnectCLI(instance_bundles, "", cli_command, logger.get_logger()) cli.invoke_command() @@ -153,7 +154,7 @@ def test_mscp(self, mock_push_key, mock_run): mock_file = 'identity' - flag = '-f flag' + flags = ['-f', 'flag'] command = 'file2 file3' logger = EC2InstanceConnectLogger() instance_bundles = [{'username': self.default_user, 'instance_id': self.instance_id, @@ -166,12 +167,12 @@ def test_mscp(self, mock_instance_data.return_value = self.instance_info mock_push_key.return_value = None - expected_command = 'scp -o "IdentitiesOnly=yes" -i {0} {1} {2}@{3}:{4} {5} {6}@{7}:{8}'.format(mock_file, flag, self.default_user, - self.public_ip, 'file1', command, - self.default_user, - self.public_ip, 'file4') + expected_command = ['scp', '-o', 'IdentitiesOnly=yes', '-i', mock_file, *flags, + '{}@{}:{}'.format(self.default_user, self.public_ip, 'file1'), + command, + '{}@{}:{}'.format(self.default_user, self.public_ip, 'file4')] - cli_command = EC2InstanceConnectCommand("scp", instance_bundles, mock_file, flag, command, logger.get_logger()) + cli_command = EC2InstanceConnectCommand("scp", instance_bundles, mock_file, flags, command, logger.get_logger()) cli = EC2InstanceConnectCLI(instance_bundles, "", cli_command, logger.get_logger()) cli.invoke_command() @@ -183,5 +184,5 @@ def test_mscp(self, def test_status_code(self): #TODO: Refine test for checking run_command status code cli = EC2InstanceConnectCLI(None, None, None, None) - code = cli.run_command("echo ok; exit -1;") + code = cli.run_command(["sh", "-c", "echo ok; exit -1;"]) self.assertEqual(code, 255) diff --git a/tests/test_input_parser.py b/tests/test_input_parser.py index e04fde9..b7ca3ff 100644 --- a/tests/test_input_parser.py +++ b/tests/test_input_parser.py @@ -41,7 +41,7 @@ def test_basic_target(self): self.assertEqual(bundles, [{'username': self.default_user, 'instance_id': self.instance_id, 'target': None, 'zone': None, 'region': None, 'profile': self.profile}]) - self.assertEqual(flags, '') + self.assertEqual(flags, []) self.assertEqual(command, '') def test_username(self): @@ -51,7 +51,7 @@ def test_username(self): self.assertEqual(bundles, [{'username': 'myuser', 'instance_id': self.instance_id, 'target': None, 'zone': None, 'region': None, 'profile': self.profile}]) - self.assertEqual(flags, '') + self.assertEqual(flags, []) self.assertEqual(command, '') def test_dns_name(self): @@ -63,7 +63,7 @@ def test_dns_name(self): self.assertEqual(bundles, [{'username': self.default_user, 'instance_id': self.instance_id, 'target': self.dns_name, 'zone': self.availability_zone, 'region': self.region, 'profile': self.profile}]) - self.assertEqual(flags, '') + self.assertEqual(flags, []) self.assertEqual(command, '') def test_flags(self): @@ -73,7 +73,7 @@ def test_flags(self): self.assertEqual(bundles, [{'username': 'login', 'instance_id': self.instance_id, 'target': None, 'zone': None, 'region': None, 'profile': self.profile}]) - self.assertEqual(flags, '-1 -l login') + self.assertEqual(flags, ['-1', '-l', 'login']) self.assertEqual(command, '') def test_command(self): @@ -83,7 +83,7 @@ def test_command(self): self.assertEqual(bundles, [{'username': self.default_user, 'instance_id': self.instance_id, 'target': None, 'zone': None, 'region': None, 'profile': self.profile}]) - self.assertEqual(flags, '') + self.assertEqual(flags, []) self.assertEqual(command, 'uname -a') def test_sftp(self): @@ -95,7 +95,7 @@ def test_sftp(self): self.assertEqual(bundles, [{'username': self.default_user, 'instance_id': self.instance_id, 'target': None, 'zone': None, 'region': None, 'profile': self.profile, 'file': 'first_file'}]) - self.assertEqual(flags, '') + self.assertEqual(flags, []) self.assertEqual(command, 'second_file') def test_invalid_username(self):