Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### New Features and Improvements

### Bug Fixes
* Pass `--profile` to CLI token source when profile is set, and add `--host` fallback for older CLIs that don't support `--profile` ([#682](https://github.com/databricks/databricks-sdk-java/pull/682)).

### Security Vulnerabilities

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,64 @@
import java.util.Arrays;
import java.util.List;
import org.apache.commons.io.IOUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@InternalApi
public class CliTokenSource implements TokenSource {
private static final Logger LOG = LoggerFactory.getLogger(CliTokenSource.class);

private List<String> cmd;
private String tokenTypeField;
private String accessTokenField;
private String expiryField;
private Environment env;
// fallbackCmd is tried when the primary command fails with "unknown flag: --profile",
// indicating the CLI is too old to support --profile. Can be removed once support
// for CLI versions predating --profile is dropped.
// See: https://github.com/databricks/databricks-sdk-go/pull/1497
private List<String> fallbackCmd;

/**
* Internal exception that carries the clean stderr message but exposes full output for checks.
*/
static class CliCommandException extends IOException {
private final String fullOutput;

CliCommandException(String message, String fullOutput) {
super(message);
this.fullOutput = fullOutput;
}

String getFullOutput() {
return fullOutput;
}
}

public CliTokenSource(
List<String> cmd,
String tokenTypeField,
String accessTokenField,
String expiryField,
Environment env) {
this(cmd, tokenTypeField, accessTokenField, expiryField, env, null);
}

public CliTokenSource(
List<String> cmd,
String tokenTypeField,
String accessTokenField,
String expiryField,
Environment env,
List<String> fallbackCmd) {
super();
this.cmd = OSUtils.get(env).getCliExecutableCommand(cmd);
this.tokenTypeField = tokenTypeField;
this.accessTokenField = accessTokenField;
this.expiryField = expiryField;
this.env = env;
this.fallbackCmd =
fallbackCmd != null ? OSUtils.get(env).getCliExecutableCommand(fallbackCmd) : null;
}

/**
Expand Down Expand Up @@ -87,10 +124,9 @@ private String getProcessStream(InputStream stream) throws IOException {
return new String(bytes);
}

@Override
public Token getToken() {
private Token execCliCommand(List<String> cmdToRun) throws IOException {
try {
ProcessBuilder processBuilder = new ProcessBuilder(cmd);
ProcessBuilder processBuilder = new ProcessBuilder(cmdToRun);
processBuilder.environment().putAll(env.getEnv());
Process process = processBuilder.start();
String stdout = getProcessStream(process.getInputStream());
Expand All @@ -99,9 +135,10 @@ public Token getToken() {
if (exitCode != 0) {
if (stderr.contains("not found")) {
throw new DatabricksException(stderr);
} else {
throw new IOException(stderr);
}
// getMessage() returns the clean stderr-based message; getFullOutput() exposes
// both streams so the caller can check for "unknown flag: --profile" in either.
throw new CliCommandException("cannot get access token: " + stderr, stdout + "\n" + stderr);
}
JsonNode jsonNode = new ObjectMapper().readTree(stdout);
String tokenType = jsonNode.get(tokenTypeField).asText();
Expand All @@ -111,8 +148,33 @@ public Token getToken() {
return new Token(accessToken, tokenType, expiresOn);
} catch (DatabricksException e) {
throw e;
} catch (InterruptedException | IOException e) {
throw new DatabricksException("cannot get access token: " + e.getMessage(), e);
} catch (InterruptedException e) {
throw new IOException("cannot get access token: " + e.getMessage(), e);
}
}

@Override
public Token getToken() {
try {
return execCliCommand(this.cmd);
} catch (IOException e) {
String textToCheck =
e instanceof CliCommandException
? ((CliCommandException) e).getFullOutput()
: e.getMessage();
if (fallbackCmd != null
&& textToCheck != null
&& textToCheck.contains("unknown flag: --profile")) {
LOG.warn(
"Databricks CLI does not support --profile flag. Falling back to --host. "
+ "Please upgrade your CLI to the latest version.");
try {
return execCliCommand(this.fallbackCmd);
} catch (IOException fallbackException) {
throw new DatabricksException(fallbackException.getMessage(), fallbackException);
}
}
throw new DatabricksException(e.getMessage(), e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ public String authType() {
}

/**
* Builds the CLI command arguments for the databricks auth token command.
* Builds the CLI command arguments using --host (legacy path).
*
* @param cliPath Path to the databricks CLI executable
* @param config Configuration containing host, account ID, workspace ID, etc.
* @return List of command arguments
*/
List<String> buildCliCommand(String cliPath, DatabricksConfig config) {
List<String> buildHostArgs(String cliPath, DatabricksConfig config) {
List<String> cmd =
new ArrayList<>(Arrays.asList(cliPath, "auth", "token", "--host", config.getHost()));
if (config.getExperimentalIsUnifiedHost() != null && config.getExperimentalIsUnifiedHost()) {
Expand Down Expand Up @@ -57,8 +57,26 @@ private CliTokenSource getDatabricksCliTokenSource(DatabricksConfig config) {
LOG.debug("Databricks CLI could not be found");
return null;
}
List<String> cmd = buildCliCommand(cliPath, config);
return new CliTokenSource(cmd, "token_type", "access_token", "expiry", config.getEnv());

List<String> cmd;
List<String> fallbackCmd = null;

if (config.getProfile() != null) {
// When profile is set, use --profile as the primary command.
// The profile contains the full config (host, account_id, etc.).
cmd =
new ArrayList<>(
Arrays.asList(cliPath, "auth", "token", "--profile", config.getProfile()));
// Build a --host fallback for older CLIs that don't support --profile.
if (config.getHost() != null) {
fallbackCmd = buildHostArgs(cliPath, config);
}
} else {
cmd = buildHostArgs(cliPath, config);
}

return new CliTokenSource(
cmd, "token_type", "access_token", "expiry", config.getEnv(), fallbackCmd);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockConstruction;
Expand All @@ -27,9 +28,11 @@
import java.util.List;
import java.util.Map;
import java.util.TimeZone;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
Expand Down Expand Up @@ -213,4 +216,175 @@ public void testParseExpiry(String input, Instant expectedInstant, String descri
assertEquals(expectedInstant, parsedInstant);
}
}

// ---- Fallback tests for --profile flag handling ----

private CliTokenSource makeTokenSource(
Environment env, List<String> primaryCmd, List<String> fallbackCmd) {
OSUtilities osUtils = mock(OSUtilities.class);
when(osUtils.getCliExecutableCommand(any())).thenAnswer(inv -> inv.getArgument(0));
try (MockedStatic<OSUtils> mockedOSUtils = mockStatic(OSUtils.class)) {
mockedOSUtils.when(() -> OSUtils.get(any())).thenReturn(osUtils);
return new CliTokenSource(
primaryCmd, "token_type", "access_token", "expiry", env, fallbackCmd);
}
}

private String validTokenJson(String accessToken) {
String expiry =
ZonedDateTime.now()
.plusHours(1)
.format(DateTimeFormatter.ofPattern("yyyy-MM-dd'T'HH:mm:ss.SSSXXX"));
return String.format(
"{\"token_type\":\"Bearer\",\"access_token\":\"%s\",\"expiry\":\"%s\"}",
accessToken, expiry);
}

@Test
public void testFallbackOnUnknownProfileFlagInStderr() {
Environment env = mock(Environment.class);
when(env.getEnv()).thenReturn(new HashMap<>());

List<String> primaryCmd =
Arrays.asList("databricks", "auth", "token", "--profile", "my-profile");
List<String> fallbackCmdList =
Arrays.asList("databricks", "auth", "token", "--host", "https://workspace.databricks.com");

CliTokenSource tokenSource = makeTokenSource(env, primaryCmd, fallbackCmdList);

AtomicInteger callCount = new AtomicInteger(0);
try (MockedConstruction<ProcessBuilder> mocked =
mockConstruction(
ProcessBuilder.class,
(pb, context) -> {
if (callCount.getAndIncrement() == 0) {
Process failProcess = mock(Process.class);
when(failProcess.getInputStream())
.thenReturn(new ByteArrayInputStream(new byte[0]));
when(failProcess.getErrorStream())
.thenReturn(
new ByteArrayInputStream("Error: unknown flag: --profile".getBytes()));
when(failProcess.waitFor()).thenReturn(1);
when(pb.start()).thenReturn(failProcess);
} else {
Process successProcess = mock(Process.class);
when(successProcess.getInputStream())
.thenReturn(
new ByteArrayInputStream(validTokenJson("fallback-token").getBytes()));
when(successProcess.getErrorStream())
.thenReturn(new ByteArrayInputStream(new byte[0]));
when(successProcess.waitFor()).thenReturn(0);
when(pb.start()).thenReturn(successProcess);
}
})) {
Token token = tokenSource.getToken();
assertEquals("fallback-token", token.getAccessToken());
assertEquals(2, mocked.constructed().size());
}
}

@Test
public void testFallbackTriggeredWhenUnknownFlagInStdout() {
// Fallback triggers even when "unknown flag" appears in stdout rather than stderr.
Environment env = mock(Environment.class);
when(env.getEnv()).thenReturn(new HashMap<>());

List<String> primaryCmd =
Arrays.asList("databricks", "auth", "token", "--profile", "my-profile");
List<String> fallbackCmdList =
Arrays.asList("databricks", "auth", "token", "--host", "https://workspace.databricks.com");

CliTokenSource tokenSource = makeTokenSource(env, primaryCmd, fallbackCmdList);

AtomicInteger callCount = new AtomicInteger(0);
try (MockedConstruction<ProcessBuilder> mocked =
mockConstruction(
ProcessBuilder.class,
(pb, context) -> {
if (callCount.getAndIncrement() == 0) {
Process failProcess = mock(Process.class);
when(failProcess.getInputStream())
.thenReturn(
new ByteArrayInputStream("Error: unknown flag: --profile".getBytes()));
when(failProcess.getErrorStream())
.thenReturn(new ByteArrayInputStream(new byte[0]));
when(failProcess.waitFor()).thenReturn(1);
when(pb.start()).thenReturn(failProcess);
} else {
Process successProcess = mock(Process.class);
when(successProcess.getInputStream())
.thenReturn(
new ByteArrayInputStream(validTokenJson("fallback-token").getBytes()));
when(successProcess.getErrorStream())
.thenReturn(new ByteArrayInputStream(new byte[0]));
when(successProcess.waitFor()).thenReturn(0);
when(pb.start()).thenReturn(successProcess);
}
})) {
Token token = tokenSource.getToken();
assertEquals("fallback-token", token.getAccessToken());
assertEquals(2, mocked.constructed().size());
}
}

@Test
public void testNoFallbackOnRealAuthError() {
// When the primary fails with a real error (not unknown flag), no fallback is attempted.
Environment env = mock(Environment.class);
when(env.getEnv()).thenReturn(new HashMap<>());

List<String> primaryCmd =
Arrays.asList("databricks", "auth", "token", "--profile", "my-profile");
List<String> fallbackCmdList =
Arrays.asList("databricks", "auth", "token", "--host", "https://workspace.databricks.com");

CliTokenSource tokenSource = makeTokenSource(env, primaryCmd, fallbackCmdList);

try (MockedConstruction<ProcessBuilder> mocked =
mockConstruction(
ProcessBuilder.class,
(pb, context) -> {
Process failProcess = mock(Process.class);
when(failProcess.getInputStream()).thenReturn(new ByteArrayInputStream(new byte[0]));
when(failProcess.getErrorStream())
.thenReturn(
new ByteArrayInputStream(
"databricks OAuth is not configured for this host".getBytes()));
when(failProcess.waitFor()).thenReturn(1);
when(pb.start()).thenReturn(failProcess);
})) {
DatabricksException ex = assertThrows(DatabricksException.class, tokenSource::getToken);
assertTrue(ex.getMessage().contains("databricks OAuth is not configured"));
assertEquals(1, mocked.constructed().size());
}
}

@Test
public void testNoFallbackWhenFallbackCmdNotSet() {
// When fallbackCmd is null and the primary fails with unknown flag, original error propagates.
Environment env = mock(Environment.class);
when(env.getEnv()).thenReturn(new HashMap<>());

List<String> primaryCmd =
Arrays.asList("databricks", "auth", "token", "--profile", "my-profile");

CliTokenSource tokenSource = makeTokenSource(env, primaryCmd, null);

try (MockedConstruction<ProcessBuilder> mocked =
mockConstruction(
ProcessBuilder.class,
(pb, context) -> {
Process failProcess = mock(Process.class);
when(failProcess.getInputStream()).thenReturn(new ByteArrayInputStream(new byte[0]));
when(failProcess.getErrorStream())
.thenReturn(
new ByteArrayInputStream("Error: unknown flag: --profile".getBytes()));
when(failProcess.waitFor()).thenReturn(1);
when(pb.start()).thenReturn(failProcess);
})) {
DatabricksException ex = assertThrows(DatabricksException.class, tokenSource::getToken);
assertTrue(ex.getMessage().contains("unknown flag: --profile"));
assertEquals(1, mocked.constructed().size());
}
}
}
Loading
Loading