diff --git a/DESCRIPTION b/DESCRIPTION index d0fbc2c..2534b55 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -14,7 +14,6 @@ Description: 'MLFlow' is an open-source tool for track not needing 'reticulate', relying on 'aws.s3' where possible, etc. License: Apache License (>= 2) Imports: - aws.s3, base64enc, checkmate, git2r, @@ -22,6 +21,7 @@ Imports: jsonlite, lifecycle, magrittr, + paws.storage, purrr, rlang, stringr, @@ -39,6 +39,7 @@ LazyData: true Roxygen: list(markdown = TRUE) RoxygenNote: 7.1.2 Collate: + 'artifacts.R' 'client.R' 'experiments.R' 'globals.R' diff --git a/NAMESPACE b/NAMESPACE index c15a92f..cfa0180 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -51,6 +51,7 @@ export(rename_experiment) export(rename_registered_model) export(restore_experiment) export(restore_run) +export(s3_select_from_artifact) export(search_runs) export(set_active_experiment_id) export(set_experiment_tag) @@ -62,9 +63,6 @@ export(transition_model_version_stage) export(update_model_version) export(update_registered_model) export(with.mlflow_run) -importFrom(aws.s3,put_object) -importFrom(aws.s3,s3read_using) -importFrom(aws.s3,s3write_using) importFrom(base64enc,base64encode) importFrom(checkmate,assert_class) importFrom(checkmate,assert_data_frame) @@ -90,6 +88,7 @@ importFrom(jsonlite,toJSON) importFrom(lifecycle,deprecated) importFrom(magrittr,"%>%") importFrom(magrittr,add) +importFrom(paws.storage,s3) importFrom(purrr,imap) importFrom(purrr,imap_int) importFrom(purrr,insistently) @@ -116,6 +115,7 @@ importFrom(rlang,is_symbol) importFrom(rlang,maybe_missing) importFrom(rlang,names2) importFrom(rlang,warn) +importFrom(stringr,str_extract) importFrom(stringr,str_remove) importFrom(stringr,str_split) importFrom(stringr,str_sub) @@ -124,4 +124,5 @@ importFrom(tibble,tibble) importFrom(tools,file_ext) importFrom(utils,askYesNo) importFrom(utils,packageVersion) +importFrom(utils,read.csv) importFrom(withr,with_options) diff --git a/NEWS.md b/NEWS.md index 69b980e..2ba5a19 100644 --- 
a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,8 @@ +# lightMLFlow 0.7.0 + +* Uses `paws.storage` instead of `aws.s3` for interacting with S3 +* Adds a method to interact with the [S3 Select API](https://docs.aws.amazon.com/AmazonS3/latest/userguide/selecting-content-from-objects.html) + # lightMLFlow 0.6.6 * Uses `purrr::insistently` to retry artifact saves and loads, since S3's API has a habit of returning a `500` every once in a while. Five retries by default, with an exponential backoff. diff --git a/R/artifacts.R b/R/artifacts.R new file mode 100644 index 0000000..47e4d5d --- /dev/null +++ b/R/artifacts.R @@ -0,0 +1,346 @@ +#' @importFrom stringr str_extract +generate_s3_key_bucket_ext <- function(artifact_name, run_id = get_active_run_id(), client = mlflow_client()) { + artifact_location <- get_artifact_path( + run_id = run_id, + client = client + ) + + without_s3_prefix <- str_remove( + artifact_location, + "s3://" + ) + + bucket <- str_extract( + without_s3_prefix, + ".+?(?=/)" + ) + + path <- without_s3_prefix %>% + str_remove( + ".+?(?=/)" + ) %>% + str_sub(start = 2L) + + key <- paste( + path, artifact_name, sep = "/" + ) + + ext <- paste0(".", file_ext(key)) + + list( + bucket = bucket, + key = key, + ext = ext + ) +} + +#' Load an artifact into an R object +#' +#' @importFrom checkmate assert_function +#' +#' @param artifact_name The name of the artifact to load +#' @param run_id A run ID to find the URI for +#' @param client An MLFlow client +#' @param FUN a function to use to load the artifact +#' @param \dots Additional arguments to pass on to `s3read_using` +#' @param pause_base,max_times,pause_cap See \link[purrr]{insistently} +#' +#' @return An R object. The result of `s3read_using` +#' @export +load_artifact <- function(artifact_name, FUN = readRDS, run_id = get_active_run_id(), client = mlflow_client(), pause_base = .5, max_times = 5, pause_cap = 60, ...) 
{ + + assert_function(FUN) + assert_string(artifact_name) + assert_string(run_id) + assert_mlflow_client(client) + + s3_path_info <- generate_s3_key_bucket_ext( + artifact_name = artifact_name, + run_id = run_id, + client = client + ) + + s3 <- s3() + + tmp <- tempfile(fileext = s3_path_info$ext) + on.exit(unlink(tmp, recursive = TRUE)) + + rate <- rate_backoff( + pause_base = pause_base, + max_times = max_times, + pause_cap = pause_cap + ) + + insistently_read <- insistently( + s3$download_file, + rate = rate, + quiet = FALSE + ) + + insistently_read( + Bucket = s3_path_info$bucket, + Key = s3_path_info$key, + Filename = tmp + ) + + FUN( + tmp, + ... + ) +} + +#' Get the artifact path for a run +#' +#' @param run_id A run id. Automatically inferred if a run is currently active. +#' @param client An MLFlow client. Auto-generated if not provided. +#' +#' @return A path to the run's artifacts in S3 +#' @export +get_artifact_path <- function(run_id = get_active_run_id(), client = mlflow_client()) { + experiment_id <- get_experiment_from_run(run_id = run_id) + + experiment <- get_experiment( + experiment_id = experiment_id, + client = client + ) + + paste( + experiment$artifact_location, + run_id, + "artifacts", + sep = "/" + ) +} +#' List Artifacts +#' +#' Gets a list of artifacts. +#' +#' @param path The run's relative artifact path to list from. If not specified, it is +#' set to the root artifact path +#' @param run_id A run id Automatically inferred if a run is currently active. +#' @param client An MLFlow client. Defaults to `NULL` and will be auto-generated. +#' +#' @importFrom purrr transpose +#' @importFrom rlang inform +#' +#' @return A `data.frame` of the artifacts at the path provided for the run provided. 
+#' @export +list_artifacts <- function(path = NULL, run_id = get_active_run_id(), client = mlflow_client()) { + + assert_string(path, null.ok = TRUE) + assert_string(run_id) + assert_mlflow_client(client) + + response <- call_mlflow_api( + "artifacts", "list", + client = client, + verb = "GET", + query = list( + run_id = run_id, + path = path + ) + ) + + files_list <- if (!is.null(response$files)) response$files else list() + files_list <- map(files_list, function(file_info) { + if (is.null(file_info$file_size)) { + file_info$file_size <- NA + } + file_info + }) + + files_list %>% + transpose() %>% + map(unlist) %>% + as.data.frame() +} + +#' Log Artifact +#' +#' Logs a specific file or directory as an artifact for a run. Modeled after `aws.s3::s3write_using` +#' +#' @param x The object to log as an artifact +#' @param FUN the function to use to save the artifact +#' @param filename the name of the file to save +#' @param run_id A run uuid. Automatically inferred if a run is currently active. +#' @param client An MLFlow client. Auto-generated if not provided +#' @param pause_base,max_times,pause_cap See \link[purrr]{insistently} +#' @param ... Additional arguments to pass to `FUN` +#' +#' @details +#' +#' When logging to Amazon S3, ensure that you have the s3:PutObject, s3:GetObject, +#' s3:ListBucket, and s3:GetBucketLocation permissions on your bucket. +#' +#' Additionally, at least the \code{AWS_ACCESS_KEY_ID} and \code{AWS_SECRET_ACCESS_KEY} +#' environment variables must be set to the corresponding key and secrets provided +#' by Amazon IAM. +#' +#' @importFrom stringr str_remove str_split str_sub +#' @importFrom purrr insistently rate_backoff +#' +#' @return The path to the file, invisibly +#' @export +log_artifact <- function(x, FUN, filename, run_id, client = mlflow_client(), pause_base = .5, max_times = 5, pause_cap = 60, ...) 
{ + UseMethod("log_artifact") +} + +#' @rdname log_artifact +#' @export +log_artifact.default <- function(x, FUN = saveRDS, filename, run_id = get_active_run_id(), client = mlflow_client(), pause_base = .5, max_times = 5, pause_cap = 60, ...) { + + check_required(x) + check_required(filename) + + s3_key_bucket_ext <- generate_s3_key_bucket_ext(filename, run_id, client) + + tmp <- tempfile(fileext = s3_key_bucket_ext$ext) + on.exit(unlink(tmp, recursive = TRUE)) + + FUN( + x, + tmp, + ... + ) + + s3 <- s3() + + rate <- rate_backoff( + pause_base = pause_base, + max_times = max_times, + pause_cap = pause_cap + ) + + insistently_write <- insistently( + s3$put_object, + rate = rate, + quiet = FALSE + ) + + insistently_write( + Bucket = s3_key_bucket_ext$bucket, + Key = s3_key_bucket_ext$key, + Body = tmp + ) + + invisible(paste("s3:/", s3_key_bucket_ext$bucket, s3_key_bucket_ext$key, sep = "/")) +} + +#' @importFrom tools file_ext +#' @importFrom paws.storage s3 +#' @rdname log_artifact +#' @export +log_artifact.ggplot <- function(x, FUN, filename, run_id = get_active_run_id(), client = mlflow_client(), pause_base = .5, max_times = 5, pause_cap = 60, ...) { + + check_required(x) + check_required(FUN) + check_required(filename) + + ## based on https://github.com/hrbrmstr/hrbrthemes/blob/master/R/aaa.r + if (isFALSE(requireNamespace("ggplot2", quietly = TRUE))) { + abort(paste0( + "Package `ggplot2` required for `ggsave`.\n", + "Please install and try again." + )) + } + + s3_key_bucket_ext <- generate_s3_key_bucket_ext( + artifact_name = filename, + run_id = run_id, + client = client + ) + + temp_file <- tempfile(fileext = s3_key_bucket_ext$ext) + on.exit(unlink(temp_file, recursive = TRUE)) + + ggplot2::ggsave(filename = temp_file, plot = x, ...) 
+ + s3 <- s3() + + rate <- rate_backoff( + pause_base = pause_base, + max_times = max_times, + pause_cap = pause_cap + ) + + insistently_put <- insistently( + s3$put_object, + rate = rate, + quiet = FALSE + ) + + insistently_put( + Bucket = s3_key_bucket_ext$bucket, + Key = s3_key_bucket_ext$key, + Body = temp_file + ) + + invisible(paste("s3:/", s3_key_bucket_ext$bucket, s3_key_bucket_ext$key, sep = "/")) +} + +#' Call the S3 SELECT API on a CSV artifact +#' +#' @param artifact_name The name of the artifact to `SELECT` from +#' @param run_id A run id Automatically inferred if a run is currently active. +#' @param client An MLFlow client. Defaults to `NULL` and will be auto-generated. +#' @param pause_base,max_times,pause_cap See \link[purrr]{insistently} +#' @param Expression,ExpressionType,InputSerialization,OutputSerialization See \link[paws.storage]{s3_select_object_content} +#' @param \dots Additional arguments to pass to \link[paws.storage]{s3_select_object_content} +#' +#' @importFrom utils read.csv +#' +#' @return A data.frame of the query result +#' @export +s3_select_from_artifact <- function( + artifact_name, + run_id = get_active_run_id(), + client = mlflow_client(), + pause_base = .5, + max_times = 5, + pause_cap = 60, + Expression, + ExpressionType = "SQL", + InputSerialization = list(CSV = list(FileHeaderInfo = "NONE", RecordDelimiter = "\n", FieldDelimiter = ","), CompressionType = "NONE"), + OutputSerialization = list(CSV = list(RecordDelimiter = "\n", FieldDelimiter = ",", QuoteCharacter = '"', QuoteFields = "ASNEEDED")), + ... 
+) { + + assert_string(artifact_name) + assert_string(run_id) + assert_mlflow_client(client) + + s3_bucket_key_ext <- generate_s3_key_bucket_ext( + artifact_name = artifact_name, + run_id = run_id, + client = client + ) + + s3 <- s3() + + rate <- rate_backoff( + pause_base = pause_base, + max_times = max_times, + pause_cap = pause_cap + ) + + insistently_read <- insistently( + s3$select_object_content, + rate = rate, + quiet = FALSE + ) + + result <- insistently_read( + Bucket = s3_bucket_key_ext$bucket, + Key = s3_bucket_key_ext$key, + Expression = Expression, + ExpressionType = ExpressionType, + InputSerialization = InputSerialization, + OutputSerialization = OutputSerialization, + ... + ) + + read.csv( + text = result$Payload$Records$Payload, + header = isTRUE(InputSerialization$CSV$FileHeaderInfo == "NONE") + ) +} diff --git a/R/runs.R b/R/runs.R index f348e23..b382a28 100644 --- a/R/runs.R +++ b/R/runs.R @@ -579,119 +579,6 @@ search_runs <- function(experiment_ids, run_view_type = c("ACTIVE_ONLY", "DELETE do.call("rbind", runs_list) %||% data.frame() } -#' Load an artifact into an R object -#' -#' @importFrom checkmate assert_function -#' @importFrom aws.s3 s3read_using -#' -#' @param artifact_name The name of the artifact to load -#' @param run_id A run ID to find the URI for -#' @param client An MLFlow client -#' @param FUN a function to use to load the artifact -#' @param \dots Additional arguments to pass on to `s3read_using` -#' @param pause_base,max_times,pause_cap See \link[purrr]{insistently} -#' -#' @return An R object. The result of `s3read_using` -#' @export -load_artifact <- function(artifact_name, FUN = readRDS, run_id = get_active_run_id(), client = mlflow_client(), pause_base = .5, max_times = 5, pause_cap = 60, ...) 
{ - - assert_function(FUN) - assert_string(artifact_name) - assert_string(run_id) - assert_mlflow_client(client) - - artifact_location <- get_artifact_path( - run_id = run_id, - client = client - ) - - rate <- rate_backoff( - pause_base = pause_base, - max_times = max_times, - pause_cap = pause_cap - ) - - insistently_read <- insistently( - s3read_using, - rate = rate, - quiet = FALSE - ) - - object <- insistently_read( - FUN = FUN, - ..., - object = paste(artifact_location, artifact_name, sep = "/") - ) - - object -} - -#' Get the artifact path for a run -#' -#' @param run_id A run id. Automatically inferred if a run is currently active. -#' @param client An MLFlow client. Auto-generated if not provided. -#' -#' @return A path to the run's artifacts in S3 -#' @export -get_artifact_path <- function(run_id = get_active_run_id(), client = mlflow_client()) { - experiment_id <- get_experiment_from_run(run_id = run_id) - - experiment <- get_experiment( - experiment_id = experiment_id, - client = client - ) - - paste( - experiment$artifact_location, - run_id, - "artifacts", - sep = "/" - ) -} -#' List Artifacts -#' -#' Gets a list of artifacts. -#' -#' @param path The run's relative artifact path to list from. If not specified, it is -#' set to the root artifact path -#' @param run_id A run id Automatically inferred if a run is currently active. -#' @param client An MLFlow client. Defaults to `NULL` and will be auto-generated. -#' -#' @importFrom purrr transpose -#' @importFrom rlang inform -#' -#' @return A `data.frame` of the artifacts at the path provided for the run provided. 
-#' @export -list_artifacts <- function(path = NULL, run_id = get_active_run_id(), client = mlflow_client()) { - - assert_string(path, null.ok = TRUE) - assert_string(run_id) - assert_mlflow_client(client) - - response <- call_mlflow_api( - "artifacts", "list", - client = client, - verb = "GET", - query = list( - run_id = run_id, - path = path - ) - ) - - files_list <- if (!is.null(response$files)) response$files else list() - files_list <- map(files_list, function(file_info) { - if (is.null(file_info$file_size)) { - file_info$file_size <- NA - } - file_info - }) - - files_list %>% - transpose() %>% - map(unlist) %>% - as.data.frame() -} - set_terminated <- function(status, end_time, run_id, client) { data <- list( @@ -712,124 +599,6 @@ get_experiment_from_run <- function(run_id) { unique() } -#' Log Artifact -#' -#' Logs a specific file or directory as an artifact for a run. Modeled after `aws.s3::s3write_using` -#' -#' @param x The object to log as an artifact -#' @param FUN the function to use to save the artifact -#' @param filename the name of the file to save -#' @param run_id A run uuid. Automatically inferred if a run is currently active. -#' @param client An MLFlow client. Auto-generated if not provided -#' @param pause_base,max_times,pause_cap See \link[purrr]{insistently} -#' @param ... Additional arguments to pass to `aws.s3::s3write_using` -#' -#' @details -#' -#' When logging to Amazon S3, ensure that you have the s3:PutObject, s3:GetObject, -#' s3:ListBucket, and s3:GetBucketLocation permissions on your bucket. -#' -#' Additionally, at least the \code{AWS_ACCESS_KEY_ID} and \code{AWS_SECRET_ACCESS_KEY} -#' environment variables must be set to the corresponding key and secrets provided -#' by Amazon IAM. 
-#' -#' @importFrom stringr str_remove str_split str_sub -#' @importFrom aws.s3 s3write_using -#' @importFrom purrr insistently rate_backoff -#' -#' @return The path to the file, invisibly -#' @export -log_artifact <- function(x, FUN, filename, run_id, client = mlflow_client(), pause_base = .5, max_times = 5, pause_cap = 60, ...) { - UseMethod("log_artifact") -} - -#' @rdname log_artifact -#' @export -log_artifact.default <- function(x, FUN = saveRDS, filename, run_id = get_active_run_id(), client = mlflow_client(), pause_base = .5, max_times = 5, pause_cap = 60, ...) { - - check_required(x) - check_required(filename) - - artifact_dir = get_artifact_path( - run_id = run_id, - client = client - ) - - artifact_filepath <- paste(artifact_dir, filename, sep = "/") - - rate <- rate_backoff( - pause_base = pause_base, - max_times = max_times, - pause_cap = pause_cap - ) - - insistently_write <- insistently( - s3write_using, - rate = rate, - quiet = FALSE - ) - - insistently_write( - x = x, - FUN = FUN, - ..., - object = artifact_filepath - ) - - invisible(artifact_filepath) -} - -#' @importFrom aws.s3 put_object -#' @importFrom tools file_ext -#' @rdname log_artifact -#' @export -log_artifact.ggplot <- function(x, FUN, filename, run_id = get_active_run_id(), client = mlflow_client(), pause_base = .5, max_times = 5, pause_cap = 60, ...) { - - check_required(x) - check_required(FUN) - check_required(filename) - - ## based on https://github.com/hrbrmstr/hrbrthemes/blob/master/R/aaa.r - if (isFALSE(requireNamespace("ggplot2", quietly = TRUE))) { - abort( - "Package `ggplot2` required for `ggsave`.\n", - "Please install and try again." - ) - } - - artifact_dir = get_artifact_path( - run_id = run_id, - client = client - ) - - artifact_filepath <- paste(artifact_dir, filename, sep = "/") - - ext <- file_ext(artifact_filepath) - temp_file <- tempfile(fileext = ext) - on.exit(unlink(temp_file, recursive = TRUE)) - - ggplot2::ggsave(filename = temp_file, plot = x, ...) 
- - rate <- rate_backoff( - pause_base = pause_base, - max_times = max_times, - pause_cap = pause_cap - ) - - insistently_put <- insistently( - put_object, - rate = rate, - quiet = FALSE - ) - - insistently_put( - file = temp_file, - object = artifact_filepath - ) - - invisible(artifact_filepath) -} - #' Record logged model metadata with the tracking server. #' #' @param model_spec A model specification. diff --git a/README.Rmd b/README.Rmd index d6b5cf8..017969e 100644 --- a/README.Rmd +++ b/README.Rmd @@ -48,7 +48,7 @@ However, there are also significant advantages to using `lightMLFlow` over `mlfl * `lightMLFlow` features a friendlier API, with significantly fewer functions, no `mlflow::mlflow_*` function prefixing (following Tidyverse conventions, `lightMLFlow` function names are verbs), and improved error handling. * `lightMLFlow` fixes some bugs in `mlflow`'s API wrapping functions. * `lightMLFlow` is significantly more lightweight than `mlflow`. It doesn't depend on `httpuv`, `reticulate`, or `swagger`, and has a more minimal footprint in general. -* `lightMLFlow` uses `aws.s3` to put and save objects to and from S3, which means you don't need to have a `boto3` install on your machine running your `MLFlow` code. This is an essential change, as it means that `lightMLFlow` does not require _any_ Python infrastructure, as opposed to `mlflow`, which does. +* `lightMLFlow` uses `paws` to put and save objects to and from S3, which means you don't need to have a `boto3` install on your machine running your `MLFlow` code. This is an essential change, as it means that `lightMLFlow` does not require _any_ Python infrastructure, as opposed to `mlflow`, which does. * `mlflow` (and, specifically, `MLFlow Projects`) doesn't play particularly nicely with `renv`. The reason for that is that an `MLProject` file that's pointed at a Git repo will try to clone and run the code from scratch. 
But with `renv`, we like restoring a package cache in CI and baking it into the Docker image that the code lives in so that we don't need to install all of the R packages the project needs every time we run the project. `lightMLFlow` hacks its way around this problem by allowing the user to run `set_git_tracking_tags()`, which tricks the MLFlow REST API into thinking that the code was run from an MLFlow Project even when it wasn't. This lets you keep your normal (e.g.) `renv` workflow in place and get the benefit of linked Git commits in the MLFlow UI without actually needing any of the `MLProject` infrastructure or setup steps. * For artifact and model logging, `lightMLFlow` logs R objects directly so that you don't need to worry about first saving a file to disk and then copying it to your artifact store. * In addition, `lightMLFlow` allows artifacts to be loaded directly into the R session in one shot, instead of first being saved to disk and then loaded afterwards. This eliminates lines of code and the headache associated with going S3 --> disk --> R by abstracting away the disk reads and writes. diff --git a/README.md b/README.md index f34da18..3b03bea 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ over `mlflow`: - `lightMLFlow` is significantly more lightweight than `mlflow`. It doesn’t depend on `httpuv`, `reticulate`, or `swagger`, and has a more minimal footprint in general. -- `lightMLFlow` uses `aws.s3` to put and save objects to and from S3, +- `lightMLFlow` uses `paws` to put and save objects to and from S3, which means you don’t need to have a `boto3` install on your machine running your `MLFlow` code. 
This is an essential change, as it means that `lightMLFlow` does not require *any* Python infrastructure, as diff --git a/man/get_artifact_path.Rd b/man/get_artifact_path.Rd index f3192bb..a2f0dc8 100644 --- a/man/get_artifact_path.Rd +++ b/man/get_artifact_path.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/runs.R +% Please edit documentation in R/artifacts.R \name{get_artifact_path} \alias{get_artifact_path} \title{Get the artifact path for a run} diff --git a/man/list_artifacts.Rd b/man/list_artifacts.Rd index 04be2d7..8dc10e3 100644 --- a/man/list_artifacts.Rd +++ b/man/list_artifacts.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/runs.R +% Please edit documentation in R/artifacts.R \name{list_artifacts} \alias{list_artifacts} \title{List Artifacts} diff --git a/man/load_artifact.Rd b/man/load_artifact.Rd index 80c893d..f3a4a5e 100644 --- a/man/load_artifact.Rd +++ b/man/load_artifact.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/runs.R +% Please edit documentation in R/artifacts.R \name{load_artifact} \alias{load_artifact} \title{Load an artifact into an R object} diff --git a/man/log_artifact.Rd b/man/log_artifact.Rd index 493ee0c..4030621 100644 --- a/man/log_artifact.Rd +++ b/man/log_artifact.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/runs.R +% Please edit documentation in R/artifacts.R \name{log_artifact} \alias{log_artifact} \alias{log_artifact.default} @@ -55,7 +55,7 @@ log_artifact( \item{pause_base, max_times, pause_cap}{See \link[purrr]{insistently}} -\item{...}{Additional arguments to pass to \code{aws.s3::s3write_using}} +\item{...}{Additional arguments to pass to \code{FUN}} } \value{ The path to the file, invisibly diff --git a/man/s3_select_from_artifact.Rd b/man/s3_select_from_artifact.Rd new file mode 100644 index 0000000..dddd81e --- 
/dev/null +++ b/man/s3_select_from_artifact.Rd @@ -0,0 +1,41 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/artifacts.R +\name{s3_select_from_artifact} +\alias{s3_select_from_artifact} +\title{Call the S3 SELECT API on a CSV artifact} +\usage{ +s3_select_from_artifact( + artifact_name, + run_id = get_active_run_id(), + client = mlflow_client(), + pause_base = 0.5, + max_times = 5, + pause_cap = 60, + Expression, + ExpressionType = "SQL", + InputSerialization = list(CSV = list(FileHeaderInfo = "NONE", RecordDelimiter = + "\\n", FieldDelimiter = ","), CompressionType = "NONE"), + OutputSerialization = list(CSV = list(RecordDelimiter = "\\n", FieldDelimiter = ",", + QuoteCharacter = "\\"", QuoteFields = "ASNEEDED")), + ... +) +} +\arguments{ +\item{artifact_name}{The name of the artifact to \code{SELECT} from} + +\item{run_id}{A run id Automatically inferred if a run is currently active.} + +\item{client}{An MLFlow client. Defaults to \code{NULL} and will be auto-generated.} + +\item{pause_base, max_times, pause_cap}{See \link[purrr]{insistently}} + +\item{Expression, ExpressionType, InputSerialization, OutputSerialization}{See \link[paws.storage]{s3_select_object_content}} + +\item{\dots}{Additional arguments to pass to \link[paws.storage]{s3_select_object_content}} +} +\value{ +A data.frame of the query result +} +\description{ +Call the S3 SELECT API on a CSV artifact +} diff --git a/renv.lock b/renv.lock index 4986a01..3f10130 100644 --- a/renv.lock +++ b/renv.lock @@ -51,19 +51,12 @@ "Repository": "CRAN", "Hash": "e8a22846fff485f0be3770c2da758713" }, - "aws.s3": { - "Package": "aws.s3", - "Version": "0.3.21", + "assertthat": { + "Package": "assertthat", + "Version": "0.2.1", "Source": "Repository", "Repository": "CRAN", - "Hash": "a0b873f71741bba67e3bc128d8f09fe3" - }, - "aws.signature": { - "Package": "aws.signature", - "Version": "0.6.0", - "Source": "Repository", - "Repository": "CRAN", - "Hash": 
"0006bcef272aad12f78dd5a85fc7f4fc" + "Hash": "50c838a310445e954bc13f26f26a6ecf" }, "backports": { "Package": "backports", @@ -205,13 +198,6 @@ "Repository": "CRAN", "Hash": "77bd60a6157420d4ffa93b27cf6a58b8" }, - "fs": { - "Package": "fs", - "Version": "1.5.0", - "Source": "Repository", - "Repository": "CRAN", - "Hash": "44594a07a42e5f91fac9f93fda6d0109" - }, "ggplot2": { "Package": "ggplot2", "Version": "3.3.5", @@ -366,6 +352,20 @@ "Repository": "CRAN", "Hash": "f4dbc5a47fd93d3415249884d31d6791" }, + "paws.common": { + "Package": "paws.common", + "Version": "0.3.17", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "5e4560e5d04ecef2c8975b7afd933b49" + }, + "paws.storage": { + "Package": "paws.storage", + "Version": "0.1.12", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "533b13818105b2cff51ae444070e62f6" + }, "pillar": { "Package": "pillar", "Version": "1.6.2", @@ -478,6 +478,13 @@ "Repository": "CRAN", "Hash": "6f76f71042411426ec8df6c54f34e6dd" }, + "secret": { + "Package": "secret", + "Version": "1.1.0", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "b66e4c16becd7f937c47f4a7b9cfe35d" + }, "stringi": { "Package": "stringi", "Version": "1.7.3", diff --git a/tests/testthat/test-run.R b/tests/testthat/test-run.R index 4fc457c..16286fb 100644 --- a/tests/testthat/test-run.R +++ b/tests/testthat/test-run.R @@ -279,6 +279,29 @@ test_that("Runs work", { get_active_run_id() ) + ## Tests for SELECT API + log_artifact( + iris, + write.csv, + "iris.csv" + ) + + result <- s3_select_from_artifact( + "iris.csv", + Expression = "SELECT \"Sepal.Length\" AS sl FROM s3object s WHERE CAST(\"Sepal.Length\" AS FLOAT) >= 5", + InputSerialization = list(CSV = list(FileHeaderInfo = "USE", RecordDelimiter = "\n", FieldDelimiter = ","), CompressionType = "NONE") + ) + + expect_equal( + nrow(result), + nrow(subset(iris, iris$Sepal.Length >= 5)) + ) + + expect_equal( + ncol(result), + 1 + ) + end_run() })