diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 4a01a0358..0e6750449 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -5,6 +5,7 @@ ### New Features and Improvements ### Bug Fixes +* Pass `--profile` to CLI token source when profile is set, and add `--host` fallback for older CLIs that don't support `--profile` ([#682](https://github.com/databricks/databricks-sdk-java/pull/682)). ### Security Vulnerabilities diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java index db0678331..58aaf7655 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java @@ -18,14 +18,39 @@ import java.util.Arrays; import java.util.List; import org.apache.commons.io.IOUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; @InternalApi public class CliTokenSource implements TokenSource { + private static final Logger LOG = LoggerFactory.getLogger(CliTokenSource.class); + private List cmd; private String tokenTypeField; private String accessTokenField; private String expiryField; private Environment env; + // fallbackCmd is tried when the primary command fails with "unknown flag: --profile", + // indicating the CLI is too old to support --profile. Can be removed once support + // for CLI versions predating --profile is dropped. + // See: https://github.com/databricks/databricks-sdk-go/pull/1497 + private List fallbackCmd; + + /** + * Internal exception that carries the clean stderr message but exposes full output for checks. + */ + static class CliCommandException extends IOException { + private final String fullOutput; + + CliCommandException(String message, String fullOutput) { + super(message); + this.fullOutput = fullOutput; + } + + String getFullOutput() { + return fullOutput; + } + } public CliTokenSource( List cmd, @@ -33,12 +58,24 @@ public CliTokenSource( String accessTokenField, String expiryField, Environment env) { + this(cmd, tokenTypeField, accessTokenField, expiryField, env, null); + } + + public CliTokenSource( + List cmd, + String tokenTypeField, + String accessTokenField, + String expiryField, + Environment env, + List fallbackCmd) { super(); this.cmd = OSUtils.get(env).getCliExecutableCommand(cmd); this.tokenTypeField = tokenTypeField; this.accessTokenField = accessTokenField; this.expiryField = expiryField; this.env = env; + this.fallbackCmd = + fallbackCmd != null ? OSUtils.get(env).getCliExecutableCommand(fallbackCmd) : null; } /** @@ -87,10 +124,9 @@ private String getProcessStream(InputStream stream) throws IOException { return new String(bytes); } - @Override - public Token getToken() { + private Token execCliCommand(List cmdToRun) throws IOException { try { - ProcessBuilder processBuilder = new ProcessBuilder(cmd); + ProcessBuilder processBuilder = new ProcessBuilder(cmdToRun); processBuilder.environment().putAll(env.getEnv()); Process process = processBuilder.start(); String stdout = getProcessStream(process.getInputStream()); @@ -99,9 +135,10 @@ public Token getToken() { if (exitCode != 0) { if (stderr.contains("not found")) { throw new DatabricksException(stderr); - } else { - throw new IOException(stderr); } + // getMessage() returns the clean stderr-based message; getFullOutput() exposes + // both streams so the caller can check for "unknown flag: --profile" in either. + throw new CliCommandException("cannot get access token: " + stderr, stdout + "\n" + stderr); } JsonNode jsonNode = new ObjectMapper().readTree(stdout); String tokenType = jsonNode.get(tokenTypeField).asText(); @@ -111,8 +148,33 @@ public Token getToken() { return new Token(accessToken, tokenType, expiresOn); } catch (DatabricksException e) { throw e; - } catch (InterruptedException | IOException e) { - throw new DatabricksException("cannot get access token: " + e.getMessage(), e); + } catch (InterruptedException e) { + throw new IOException("cannot get access token: " + e.getMessage(), e); + } + } + + @Override + public Token getToken() { + try { + return execCliCommand(this.cmd); + } catch (IOException e) { + String textToCheck = + e instanceof CliCommandException + ? ((CliCommandException) e).getFullOutput() + : e.getMessage(); + if (fallbackCmd != null + && textToCheck != null + && textToCheck.contains("unknown flag: --profile")) { + LOG.warn( + "Databricks CLI does not support --profile flag. Falling back to --host. " + + "Please upgrade your CLI to the latest version."); + try { + return execCliCommand(this.fallbackCmd); + } catch (IOException fallbackException) { + throw new DatabricksException(fallbackException.getMessage(), fallbackException); + } + } + throw new DatabricksException(e.getMessage(), e); } } } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksCliCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksCliCredentialsProvider.java index 7fc505583..6cfbadc3d 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksCliCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksCliCredentialsProvider.java @@ -21,13 +21,13 @@ public String authType() { } /** - * Builds the CLI command arguments for the databricks auth token command. + * Builds the CLI command arguments using --host (legacy path). * * @param cliPath Path to the databricks CLI executable * @param config Configuration containing host, account ID, workspace ID, etc. * @return List of command arguments */ - List buildCliCommand(String cliPath, DatabricksConfig config) { + List buildHostArgs(String cliPath, DatabricksConfig config) { List cmd = new ArrayList<>(Arrays.asList(cliPath, "auth", "token", "--host", config.getHost())); if (config.getExperimentalIsUnifiedHost() != null && config.getExperimentalIsUnifiedHost()) { @@ -57,8 +57,26 @@ private CliTokenSource getDatabricksCliTokenSource(DatabricksConfig config) { LOG.debug("Databricks CLI could not be found"); return null; } - List cmd = buildCliCommand(cliPath, config); - return new CliTokenSource(cmd, "token_type", "access_token", "expiry", config.getEnv()); + + List cmd; + List fallbackCmd = null; + + if (config.getProfile() != null) { + // When profile is set, use --profile as the primary command. + // The profile contains the full config (host, account_id, etc.). + cmd = + new ArrayList<>( + Arrays.asList(cliPath, "auth", "token", "--profile", config.getProfile())); + // Build a --host fallback for older CLIs that don't support --profile. + if (config.getHost() != null) { + fallbackCmd = buildHostArgs(cliPath, config); + } + } else { + cmd = buildHostArgs(cliPath, config); + } + + return new CliTokenSource( + cmd, "token_type", "access_token", "expiry", config.getEnv(), fallbackCmd); } @Override diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java index 16f110991..8476c6de5 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java @@ -2,6 +2,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mockConstruction; @@ -27,9 +28,11 @@ import java.util.List; import java.util.Map; import java.util.TimeZone; +import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; import java.util.stream.IntStream; import java.util.stream.Stream; +import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; @@ -213,4 +216,175 @@ public void testParseExpiry(String input, Instant expectedInstant, String descri assertEquals(expectedInstant, parsedInstant); } } + + // ---- Fallback tests for --profile flag handling ---- + + private CliTokenSource makeTokenSource( + Environment env, List primaryCmd, List fallbackCmd) { + OSUtilities osUtils = mock(OSUtilities.class); + when(osUtils.getCliExecutableCommand(any())).thenAnswer(inv -> inv.getArgument(0)); + try (MockedStatic mockedOSUtils = mockStatic(OSUtils.class)) { + mockedOSUtils.when(() -> OSUtils.get(any())).thenReturn(osUtils); + return new CliTokenSource( + primaryCmd, "token_type", "access_token", "expiry", env, fallbackCmd); + } + } + + private String validTokenJson(String accessToken) { + String expiry = + ZonedDateTime.now() + .plusHours(1) + .format(DateTimeFormatter.ofPattern("yyyy-MM-dd'T'HH:mm:ss.SSSXXX")); + return String.format( + "{\"token_type\":\"Bearer\",\"access_token\":\"%s\",\"expiry\":\"%s\"}", + accessToken, expiry); + } + + @Test + public void testFallbackOnUnknownProfileFlagInStderr() { + Environment env = mock(Environment.class); + when(env.getEnv()).thenReturn(new HashMap<>()); + + List primaryCmd = + Arrays.asList("databricks", "auth", "token", "--profile", "my-profile"); + List fallbackCmdList = + Arrays.asList("databricks", "auth", "token", "--host", "https://workspace.databricks.com"); + + CliTokenSource tokenSource = makeTokenSource(env, primaryCmd, fallbackCmdList); + + AtomicInteger callCount = new AtomicInteger(0); + try (MockedConstruction mocked = + mockConstruction( + ProcessBuilder.class, + (pb, context) -> { + if (callCount.getAndIncrement() == 0) { + Process failProcess = mock(Process.class); + when(failProcess.getInputStream()) + .thenReturn(new ByteArrayInputStream(new byte[0])); + when(failProcess.getErrorStream()) + .thenReturn( + new ByteArrayInputStream("Error: unknown flag: --profile".getBytes())); + when(failProcess.waitFor()).thenReturn(1); + when(pb.start()).thenReturn(failProcess); + } else { + Process successProcess = mock(Process.class); + when(successProcess.getInputStream()) + .thenReturn( + new ByteArrayInputStream(validTokenJson("fallback-token").getBytes())); + when(successProcess.getErrorStream()) + .thenReturn(new ByteArrayInputStream(new byte[0])); + when(successProcess.waitFor()).thenReturn(0); + when(pb.start()).thenReturn(successProcess); + } + })) { + Token token = tokenSource.getToken(); + assertEquals("fallback-token", token.getAccessToken()); + assertEquals(2, mocked.constructed().size()); + } + } + + @Test + public void testFallbackTriggeredWhenUnknownFlagInStdout() { + // Fallback triggers even when "unknown flag" appears in stdout rather than stderr. + Environment env = mock(Environment.class); + when(env.getEnv()).thenReturn(new HashMap<>()); + + List primaryCmd = + Arrays.asList("databricks", "auth", "token", "--profile", "my-profile"); + List fallbackCmdList = + Arrays.asList("databricks", "auth", "token", "--host", "https://workspace.databricks.com"); + + CliTokenSource tokenSource = makeTokenSource(env, primaryCmd, fallbackCmdList); + + AtomicInteger callCount = new AtomicInteger(0); + try (MockedConstruction mocked = + mockConstruction( + ProcessBuilder.class, + (pb, context) -> { + if (callCount.getAndIncrement() == 0) { + Process failProcess = mock(Process.class); + when(failProcess.getInputStream()) + .thenReturn( + new ByteArrayInputStream("Error: unknown flag: --profile".getBytes())); + when(failProcess.getErrorStream()) + .thenReturn(new ByteArrayInputStream(new byte[0])); + when(failProcess.waitFor()).thenReturn(1); + when(pb.start()).thenReturn(failProcess); + } else { + Process successProcess = mock(Process.class); + when(successProcess.getInputStream()) + .thenReturn( + new ByteArrayInputStream(validTokenJson("fallback-token").getBytes())); + when(successProcess.getErrorStream()) + .thenReturn(new ByteArrayInputStream(new byte[0])); + when(successProcess.waitFor()).thenReturn(0); + when(pb.start()).thenReturn(successProcess); + } + })) { + Token token = tokenSource.getToken(); + assertEquals("fallback-token", token.getAccessToken()); + assertEquals(2, mocked.constructed().size()); + } + } + + @Test + public void testNoFallbackOnRealAuthError() { + // When the primary fails with a real error (not unknown flag), no fallback is attempted. + Environment env = mock(Environment.class); + when(env.getEnv()).thenReturn(new HashMap<>()); + + List primaryCmd = + Arrays.asList("databricks", "auth", "token", "--profile", "my-profile"); + List fallbackCmdList = + Arrays.asList("databricks", "auth", "token", "--host", "https://workspace.databricks.com"); + + CliTokenSource tokenSource = makeTokenSource(env, primaryCmd, fallbackCmdList); + + try (MockedConstruction mocked = + mockConstruction( + ProcessBuilder.class, + (pb, context) -> { + Process failProcess = mock(Process.class); + when(failProcess.getInputStream()).thenReturn(new ByteArrayInputStream(new byte[0])); + when(failProcess.getErrorStream()) + .thenReturn( + new ByteArrayInputStream( + "databricks OAuth is not configured for this host".getBytes())); + when(failProcess.waitFor()).thenReturn(1); + when(pb.start()).thenReturn(failProcess); + })) { + DatabricksException ex = assertThrows(DatabricksException.class, tokenSource::getToken); + assertTrue(ex.getMessage().contains("databricks OAuth is not configured")); + assertEquals(1, mocked.constructed().size()); + } + } + + @Test + public void testNoFallbackWhenFallbackCmdNotSet() { + // When fallbackCmd is null and the primary fails with unknown flag, original error propagates. + Environment env = mock(Environment.class); + when(env.getEnv()).thenReturn(new HashMap<>()); + + List primaryCmd = + Arrays.asList("databricks", "auth", "token", "--profile", "my-profile"); + + CliTokenSource tokenSource = makeTokenSource(env, primaryCmd, null); + + try (MockedConstruction mocked = + mockConstruction( + ProcessBuilder.class, + (pb, context) -> { + Process failProcess = mock(Process.class); + when(failProcess.getInputStream()).thenReturn(new ByteArrayInputStream(new byte[0])); + when(failProcess.getErrorStream()) + .thenReturn( + new ByteArrayInputStream("Error: unknown flag: --profile".getBytes())); + when(failProcess.waitFor()).thenReturn(1); + when(pb.start()).thenReturn(failProcess); + })) { + DatabricksException ex = assertThrows(DatabricksException.class, tokenSource::getToken); + assertTrue(ex.getMessage().contains("unknown flag: --profile")); + assertEquals(1, mocked.constructed().size()); + } + } } diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DatabricksCliCredentialsProviderTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DatabricksCliCredentialsProviderTest.java index 478948d82..bac4a766b 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DatabricksCliCredentialsProviderTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DatabricksCliCredentialsProviderTest.java @@ -18,19 +18,19 @@ class DatabricksCliCredentialsProviderTest { private final DatabricksCliCredentialsProvider provider = new DatabricksCliCredentialsProvider(); @Test - void testBuildCliCommand_WorkspaceHost() { + void testBuildHostArgs_WorkspaceHost() { DatabricksConfig config = new DatabricksConfig().setHost(HOST); - List cmd = provider.buildCliCommand(CLI_PATH, config); + List cmd = provider.buildHostArgs(CLI_PATH, config); assertEquals(Arrays.asList(CLI_PATH, "auth", "token", "--host", HOST), cmd); } @Test - void testBuildCliCommand_AccountHost() { + void testBuildHostArgs_AccountHost() { DatabricksConfig config = new DatabricksConfig().setHost(ACCOUNT_HOST).setAccountId(ACCOUNT_ID); - List cmd = provider.buildCliCommand(CLI_PATH, config); + List cmd = provider.buildHostArgs(CLI_PATH, config); assertEquals( Arrays.asList( @@ -39,7 +39,7 @@ void testBuildCliCommand_AccountHost() { } @Test - void testBuildCliCommand_UnifiedHost_WithAccountIdAndWorkspaceId() { + void testBuildHostArgs_UnifiedHost_WithAccountIdAndWorkspaceId() { DatabricksConfig config = new DatabricksConfig() .setHost(UNIFIED_HOST) @@ -47,7 +47,7 @@ void testBuildCliCommand_UnifiedHost_WithAccountIdAndWorkspaceId() { .setAccountId(ACCOUNT_ID) .setWorkspaceId(WORKSPACE_ID); - List cmd = provider.buildCliCommand(CLI_PATH, config); + List cmd = provider.buildHostArgs(CLI_PATH, config); assertEquals( Arrays.asList( @@ -65,14 +65,14 @@ void testBuildCliCommand_UnifiedHost_WithAccountIdAndWorkspaceId() { } @Test - void testBuildCliCommand_UnifiedHost_WithAccountIdOnly() { + void testBuildHostArgs_UnifiedHost_WithAccountIdOnly() { DatabricksConfig config = new DatabricksConfig() .setHost(UNIFIED_HOST) .setExperimentalIsUnifiedHost(true) .setAccountId(ACCOUNT_ID); - List cmd = provider.buildCliCommand(CLI_PATH, config); + List cmd = provider.buildHostArgs(CLI_PATH, config); assertEquals( Arrays.asList( @@ -88,14 +88,14 @@ void testBuildCliCommand_UnifiedHost_WithAccountIdOnly() { } @Test - void testBuildCliCommand_UnifiedHost_WithWorkspaceIdOnly() { + void testBuildHostArgs_UnifiedHost_WithWorkspaceIdOnly() { DatabricksConfig config = new DatabricksConfig() .setHost(UNIFIED_HOST) .setExperimentalIsUnifiedHost(true) .setWorkspaceId(WORKSPACE_ID); - List cmd = provider.buildCliCommand(CLI_PATH, config); + List cmd = provider.buildHostArgs(CLI_PATH, config); assertEquals( Arrays.asList( @@ -111,11 +111,11 @@ void testBuildCliCommand_UnifiedHost_WithWorkspaceIdOnly() { } @Test - void testBuildCliCommand_UnifiedHost_WithNoAccountIdOrWorkspaceId() { + void testBuildHostArgs_UnifiedHost_WithNoAccountIdOrWorkspaceId() { DatabricksConfig config = new DatabricksConfig().setHost(UNIFIED_HOST).setExperimentalIsUnifiedHost(true); - List cmd = provider.buildCliCommand(CLI_PATH, config); + List cmd = provider.buildHostArgs(CLI_PATH, config); assertEquals( Arrays.asList( @@ -124,7 +124,7 @@ void testBuildCliCommand_UnifiedHost_WithNoAccountIdOrWorkspaceId() { } @Test - void testBuildCliCommand_UnifiedHostFalse_WithAccountHost() { + void testBuildHostArgs_UnifiedHostFalse_WithAccountHost() { // When experimentalIsUnifiedHost is explicitly false, should fall back to account-id logic DatabricksConfig config = new DatabricksConfig() @@ -132,7 +132,7 @@ void testBuildCliCommand_UnifiedHostFalse_WithAccountHost() { .setExperimentalIsUnifiedHost(false) .setAccountId(ACCOUNT_ID); - List cmd = provider.buildCliCommand(CLI_PATH, config); + List cmd = provider.buildHostArgs(CLI_PATH, config); assertEquals( Arrays.asList(