diff --git a/dvc_ssh/client.py b/dvc_ssh/client.py index aacb185..32328ef 100644 --- a/dvc_ssh/client.py +++ b/dvc_ssh/client.py @@ -95,7 +95,10 @@ async def _read_private_key_interactive(self, path: "FilePath") -> "SSHKey": pass raise KeyImportError("Incorrect passphrase") - def kbdint_auth_requested(self) -> str: + def kbdint_auth_requested(self): + if self._conn._options.password is not None: + return NotImplemented + return "" async def kbdint_challenge_received( diff --git a/dvc_ssh/tests/cloud.py b/dvc_ssh/tests/cloud.py index de3c7a6..367a9e9 100644 --- a/dvc_ssh/tests/cloud.py +++ b/dvc_ssh/tests/cloud.py @@ -7,6 +7,7 @@ from dvc.testing.path_info import URLInfo TEST_SSH_USER = "user" +TEST_SSH_PASSWORD = "password" TEST_SSH_KEY_PATH = os.path.join( os.path.abspath(os.path.dirname(__file__)), f"{TEST_SSH_USER}.key" ) diff --git a/dvc_ssh/tests/docker-compose.yml b/dvc_ssh/tests/docker-compose.yml index ec5bcb0..ceb1895 100644 --- a/dvc_ssh/tests/docker-compose.yml +++ b/dvc_ssh/tests/docker-compose.yml @@ -4,7 +4,9 @@ services: openssh-server: image: ghcr.io/linuxserver/openssh-server environment: + - PASSWORD_ACCESS=true - USER_NAME=user + - USER_PASSWORD=password - PUBLIC_KEY_FILE=/tmp/key ports: - 2222 diff --git a/dvc_ssh/tests/test_client.py b/dvc_ssh/tests/test_client.py new file mode 100644 index 0000000..aabb9eb --- /dev/null +++ b/dvc_ssh/tests/test_client.py @@ -0,0 +1,43 @@ +from types import SimpleNamespace + +import pytest + +import dvc_ssh.client +from dvc_ssh import SSHFileSystem +from dvc_ssh.client import InteractiveSSHClient +from dvc_ssh.tests.cloud import TEST_SSH_PASSWORD, TEST_SSH_USER + + +@pytest.mark.parametrize( + "password,expected", + [("secret", NotImplemented), (None, "")], +) +def test_kbdint_auth_requested(password, expected): + client = InteractiveSSHClient() + client._conn = SimpleNamespace(_options=SimpleNamespace(password=password)) + + result = client.kbdint_auth_requested() + + if expected is NotImplemented: + assert result is NotImplemented + else: + assert result == expected + + +def test_password_auth_uses_configured_password(ssh_server, monkeypatch, tmp_path): + monkeypatch.setenv("HOME", str(tmp_path)) + monkeypatch.delenv("SSH_AUTH_SOCK", raising=False) + + async def fail_getpass(*args, **kwargs): + raise AssertionError("_getpass should not be called") + + monkeypatch.setattr(dvc_ssh.client, "_getpass", fail_getpass) + + fs = SSHFileSystem( + host=ssh_server["host"], + port=ssh_server["port"], + user=TEST_SSH_USER, + password=TEST_SSH_PASSWORD, + ) + + assert fs.exists("/tmp")