diff --git a/README.md b/README.md index 761d453..9fc80b0 100644 --- a/README.md +++ b/README.md @@ -144,6 +144,7 @@ For Scikit-learn support (experimental, unmaintained), see [sklearn examples](ex | StringMap | Maps a list of string values to a list of other string values with a standard CASE WHEN statement. Can provide a default value for ELSE. | [Link](src/kamae/tensorflow/layers/string_map.py) | [Link](src/kamae/spark/transformers/string_map.py) | Not yet implemented | | StringIsInList | Checks if the feature is equal to at least one of the strings provided. | [Link](src/kamae/tensorflow/layers/string_isin_list.py) | [Link](src/kamae/spark/transformers/string_isin_list.py) | Not yet implemented | | StringReplace | Performs a regex replace operation on a feature with constant params or between multiple features | [Link](src/kamae/tensorflow/layers/string_replace.py) | [Link](src/kamae/spark/transformers/string_replace.py) | Not yet implemented | +| StringSequenceToEmbedding | Parses a delimited string of pre-computed embedding vectors into a `(seq_len, embedding_dim)` float tensor, with optional reversal of the non-pad portion of the sequence. | [Link](src/kamae/tensorflow/layers/string_sequence_to_embedding.py) | [Link](src/kamae/spark/transformers/string_sequence_to_embedding.py) | Not yet implemented | | StringToStringList | Splits a string by a separator, returning a list of parametrised length (with a default value for missing inputs). | [Link](src/kamae/tensorflow/layers/string_to_string_list.py) | [Link](src/kamae/spark/transformers/string_to_string_list.py) | Not yet implemented | | SubStringDelimAtIndex | Splits a string column using the provided delimiter, and returns the value at the index given. If the index is out of bounds, returns a given default value | [Link](src/kamae/tensorflow/layers/sub_string_delim_at_index.py) | [Link](src/kamae/spark/transformers/sub_string_delim_at_index.py) | Not yet implemented | | Subtract | Subtracts a constant from a single feature or subtracts multiple features from each other. | [Link](src/kamae/tensorflow/layers/subtract.py) | [Link](src/kamae/spark/transformers/subtract.py) | Not yet implemented | diff --git a/src/kamae/spark/transformers/__init__.py b/src/kamae/spark/transformers/__init__.py index 76ce6b7..0b5a824 100644 --- a/src/kamae/spark/transformers/__init__.py +++ b/src/kamae/spark/transformers/__init__.py @@ -85,6 +85,9 @@ from .string_list_to_string import StringListToStringTransformer # noqa: F401 from .string_map import StringMapTransformer # noqa: F401 from .string_replace import StringReplaceTransformer # noqa: F401 +from .string_sequence_to_embedding import ( # noqa: F401 + StringSequenceToEmbeddingTransformer, +) from .string_to_string_list import StringToStringListTransformer # noqa: F401 from .sub_string_delim_at_index import SubStringDelimAtIndexTransformer # noqa: F401 from .subtract import SubtractTransformer # noqa: F401 diff --git a/src/kamae/spark/transformers/string_sequence_to_embedding.py b/src/kamae/spark/transformers/string_sequence_to_embedding.py new file mode 100644 index 0000000..9a69eea --- /dev/null +++ b/src/kamae/spark/transformers/string_sequence_to_embedding.py @@ -0,0 +1,296 @@ +# Copyright [2024] Expedia, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=unused-argument +# pylint: disable=invalid-name +# pylint: disable=too-many-ancestors +# pylint: disable=no-member +import re +from typing import List, Optional + +import pyspark.sql.functions as F +import tensorflow as tf +from pyspark import keyword_only +from pyspark.ml.param import Param, Params, TypeConverters +from pyspark.sql import Column, DataFrame +from pyspark.sql.types import DataType, StringType + +from kamae.spark.params import SingleInputSingleOutputParams +from kamae.spark.utils import single_input_single_output_scalar_transform +from kamae.tensorflow.layers import StringSequenceToEmbeddingLayer + +from .base import BaseTransformer + + +class StringSequenceToEmbeddingParams(Params): + """ + Mixin class containing the parameters required to parse a delimited + string of embedding vectors into a dense float matrix. + """ + + seqLen = Param( + Params._dummy(), + "seqLen", + "Maximum number of vectors per sequence.", + typeConverter=TypeConverters.toInt, + ) + + embeddingDim = Param( + Params._dummy(), + "embeddingDim", + "Dimensionality of each embedding vector.", + typeConverter=TypeConverters.toInt, + ) + + separator = Param( + Params._dummy(), + "separator", + "Separator between floats within a vector.", + typeConverter=TypeConverters.toString, + ) + + sequenceSeparator = Param( + Params._dummy(), + "sequenceSeparator", + "Separator between vectors in a sequence.", + typeConverter=TypeConverters.toString, + ) + + padValue = Param( + Params._dummy(), + "padValue", + "String used to pad short sequences.", + typeConverter=TypeConverters.toString, + ) + + reverse = Param( + Params._dummy(), + "reverse", + "Reverse the non-pad portion of each sequence along the sequence axis.", + typeConverter=TypeConverters.toBoolean, + ) + + def getSeqLen(self) -> int: + return self.getOrDefault(self.seqLen) + + def setSeqLen(self, value: int) -> "StringSequenceToEmbeddingParams": + if value < 1: + raise ValueError("seqLen must be >= 1.") + return self._set(seqLen=value) + + def getEmbeddingDim(self) -> int: + return self.getOrDefault(self.embeddingDim) + + def setEmbeddingDim(self, value: int) -> "StringSequenceToEmbeddingParams": + if value < 1: + raise ValueError("embeddingDim must be >= 1.") + return self._set(embeddingDim=value) + + def getSeparator(self) -> str: + return self.getOrDefault(self.separator) + + def setSeparator(self, value: str) -> "StringSequenceToEmbeddingParams": + return self._set(separator=value) + + def getSequenceSeparator(self) -> str: + return self.getOrDefault(self.sequenceSeparator) + + def setSequenceSeparator(self, value: str) -> "StringSequenceToEmbeddingParams": + return self._set(sequenceSeparator=value) + + def getPadValue(self) -> str: + return self.getOrDefault(self.padValue) + + def setPadValue(self, value: str) -> "StringSequenceToEmbeddingParams": + return self._set(padValue=value) + + def getReverse(self) -> bool: + return self.getOrDefault(self.reverse) + + def setReverse(self, value: bool) -> "StringSequenceToEmbeddingParams": + return self._set(reverse=value) + + +class StringSequenceToEmbeddingTransformer( + BaseTransformer, + SingleInputSingleOutputParams, + StringSequenceToEmbeddingParams, +): + """ + Spark transformer that parses a delimited string of pre-computed + embedding vectors into a nested array of floats with shape + ``(seq_len, embedding_dim)``. + """ + + @keyword_only + def __init__( + self, + inputCol: Optional[str] = None, + outputCol: Optional[str] = None, + inputDtype: Optional[str] = None, + outputDtype: Optional[str] = None, + layerName: Optional[str] = None, + seqLen: int = 10, + embeddingDim: int = 32, + separator: str = "|", + sequenceSeparator: str = ",", + padValue: str = "0", + reverse: bool = False, + ) -> None: + """ + Initialises a StringSequenceToEmbeddingTransformer. + + :param inputCol: Input column name. + :param outputCol: Output column name. + :param inputDtype: Input data type to cast input column to before + transforming. + :param outputDtype: Output data type to cast the output column to after + transforming. + :param layerName: Name of the layer. Used as the name of the tensorflow + layer in the keras model. If not set, we use the uid of the Spark + transformer. + :param seqLen: Maximum number of vectors per sequence. Defaults to 10. + :param embeddingDim: Dimensionality of each embedding vector. + Defaults to 32. + :param separator: Separator between floats within a vector. + Defaults to ``"|"``. + :param sequenceSeparator: Separator between vectors in a sequence. + Defaults to ``","``. + :param padValue: String used to pad short sequences. + Defaults to ``"0"``. + :param reverse: If True, reverse the non-pad portion of each sequence. + Defaults to False. + :returns: None - class instantiated. + """ + super().__init__() + self._setDefault(seqLen=10) + self._setDefault(embeddingDim=32) + self._setDefault(separator="|") + self._setDefault(sequenceSeparator=",") + self._setDefault(padValue="0") + self._setDefault(reverse=False) + kwargs = self._input_kwargs + self.setParams(**kwargs) + + @property + def compatible_dtypes(self) -> Optional[List[DataType]]: + """ + List of compatible data types for the layer. + If the computation can be performed on any data type, return None. + + :returns: List of compatible data types for the layer. + """ + return [StringType()] + + def _transform(self, dataset: DataFrame) -> DataFrame: + """ + Transforms the input dataset. Parses the input string column into a + nested array column with shape ``(seq_len, embedding_dim)``. + + :param dataset: Pyspark dataframe to transform. + :returns: Transformed pyspark dataframe. + """ + if self.getSeparator() == self.getSequenceSeparator(): + raise ValueError("separator and sequenceSeparator must be different.") + + seq_len = self.getSeqLen() + embedding_dim = self.getEmbeddingDim() + pad_value = self.getPadValue() + reverse = self.getReverse() + total_floats = seq_len * embedding_dim + # Build a single regex pattern that matches either delimiter so we can + # split in one pass. + split_pattern = ( + f"[{re.escape(self.getSeparator())}" + f"{re.escape(self.getSequenceSeparator())}]" + ) + + input_datatype = self.get_column_datatype( + dataset=dataset, column_name=self.getInputCol() + ) + + def parse_sequence(x: Column) -> Column: + # Split the input string into flat float tokens. + tokens = F.split(x, pattern=split_pattern) + # Replace empty tokens with the pad value. + tokens = F.transform( + tokens, + lambda t: F.when(t == F.lit(""), pad_value).otherwise(t), + ) + # Truncate to at most ``total_floats`` tokens. + tokens = F.slice(tokens, 1, total_floats) + # Pad with pad_value to exactly ``total_floats`` tokens. + tokens = F.concat( + tokens, + F.array_repeat( + F.lit(pad_value), + F.greatest(F.lit(total_floats) - F.size(tokens), F.lit(0)), + ), + ) + # Cast each token to float. + float_tokens = F.transform(tokens, lambda t: t.cast("float")) + + # Reshape flat array of length seq_len * embedding_dim into a + # nested array of shape (seq_len, embedding_dim). + vectors = F.transform( + F.sequence(F.lit(0), F.lit(seq_len - 1)), + lambda i: F.slice(float_tokens, i * embedding_dim + 1, embedding_dim), + ) + + if not reverse: + return vectors + + # Count the number of non-pad vectors (a vector is pad iff all + # of its components are zero). Reverse only that prefix. + abs_sums = F.transform( + vectors, + lambda v: F.aggregate( + v, + F.lit(0.0), + lambda acc, value: acc + F.abs(value), + ), + ) + non_pad_count = F.aggregate( + abs_sums, + F.lit(0), + lambda acc, s: acc + F.when(s > F.lit(0.0), 1).otherwise(0), + ) + reversed_prefix = F.reverse(F.slice(vectors, 1, non_pad_count)) + suffix = F.slice(vectors, non_pad_count + 1, F.lit(seq_len) - non_pad_count) + return F.concat(reversed_prefix, suffix) + + output_col = single_input_single_output_scalar_transform( + input_col=F.col(self.getInputCol()), + input_col_datatype=input_datatype, + func=parse_sequence, + ) + return dataset.withColumn(self.getOutputCol(), output_col) + + def get_tf_layer(self) -> tf.keras.layers.Layer: + """ + Gets the tensorflow layer for the StringSequenceToEmbedding transformer. + + :returns: Tensorflow keras layer equivalent to this transformer. + """ + return StringSequenceToEmbeddingLayer( + name=self.getLayerName(), + input_dtype=self.getInputTFDtype(), + output_dtype=self.getOutputTFDtype(), + seq_len=self.getSeqLen(), + embedding_dim=self.getEmbeddingDim(), + separator=self.getSeparator(), + sequence_separator=self.getSequenceSeparator(), + pad_value=self.getPadValue(), + reverse=self.getReverse(), + ) diff --git a/src/kamae/tensorflow/layers/__init__.py b/src/kamae/tensorflow/layers/__init__.py index da97195..e0e5143 100644 --- a/src/kamae/tensorflow/layers/__init__.py +++ b/src/kamae/tensorflow/layers/__init__.py @@ -76,6 +76,7 @@ from .string_list_to_string import StringListToStringLayer # noqa: F401 from .string_map import StringMapLayer # noqa: F401 from .string_replace import StringReplaceLayer # noqa: F401 +from .string_sequence_to_embedding import StringSequenceToEmbeddingLayer # noqa: F401 from .string_to_string_list import StringToStringListLayer # noqa: F401 from .sub_string_delim_at_index import SubStringDelimAtIndexLayer # noqa: F401 from .subtract import SubtractLayer # noqa: F401 diff --git a/src/kamae/tensorflow/layers/string_sequence_to_embedding.py b/src/kamae/tensorflow/layers/string_sequence_to_embedding.py new file mode 100644 index 0000000..641aa57 --- /dev/null +++ b/src/kamae/tensorflow/layers/string_sequence_to_embedding.py @@ -0,0 +1,183 @@ +# Copyright [2024] Expedia, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import Any, Dict, List, Optional + +import tensorflow as tf + +import kamae +from kamae.tensorflow.typing import Tensor +from kamae.tensorflow.utils import enforce_single_tensor_input + +from .base import BaseLayer + + +@tf.keras.utils.register_keras_serializable(package=kamae.__name__) +class StringSequenceToEmbeddingLayer(BaseLayer): + """ + Parses a delimited string that encodes a sequence of pre-computed + embedding vectors into a dense float tensor. + + Each input element is a single string encoding up to ``seq_len`` + fixed-dimension vectors. Vectors are separated by ``sequence_separator`` + (default ``","``) and floats within a vector are separated by + ``separator`` (default ``"|"``). + + Strings with fewer than ``seq_len * embedding_dim`` floats are padded + with ``pad_value``; strings with more are truncated. Optionally, the + non-pad portion of each sequence can be reversed along the sequence + axis. + + Example: + layer = StringSequenceToEmbeddingLayer(seq_len=4, embedding_dim=3) + x = tf.constant([["1|2|3,4|5|6,0|0|0,0|0|0"]]) + layer(x).shape # (1, 1, 4, 3) + """ + + def __init__( + self, + name: Optional[str] = None, + input_dtype: Optional[str] = None, + output_dtype: Optional[str] = None, + seq_len: int = 10, + embedding_dim: int = 32, + separator: str = "|", + sequence_separator: str = ",", + pad_value: str = "0", + reverse: bool = False, + **kwargs: Any, + ) -> None: + """ + Initialises the StringSequenceToEmbeddingLayer. + + :param name: The name of the layer. Defaults to `None`. + :param input_dtype: The dtype to cast the input to. Defaults to `None`. + :param output_dtype: The dtype to cast the output to. Defaults to `None`. + :param seq_len: Maximum number of vectors per sequence. Defaults to 10. + :param embedding_dim: Dimensionality of each embedding vector. + Defaults to 32. + :param separator: Float separator within a vector. Defaults to ``"|"``. + :param sequence_separator: Separator between vectors. + Defaults to ``","``. + :param pad_value: String used to pad short sequences. Defaults to + ``"0"``. + :param reverse: If True, reverse the non-pad portion of each + sequence along the sequence axis. Defaults to False. + """ + super().__init__( + name=name, input_dtype=input_dtype, output_dtype=output_dtype, **kwargs + ) + if seq_len < 1: + raise ValueError("seq_len must be >= 1.") + if embedding_dim < 1: + raise ValueError("embedding_dim must be >= 1.") + if separator == sequence_separator: + raise ValueError("separator and sequence_separator must be different.") + self.seq_len = seq_len + self.embedding_dim = embedding_dim + self.separator = separator + self.sequence_separator = sequence_separator + self.pad_value = pad_value + self.reverse = reverse + + @property + def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + """ + Returns the compatible dtypes of the layer. + + :returns: The compatible dtypes of the layer. + """ + return [tf.string] + + @enforce_single_tensor_input + def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: + """ + Parses each string element into a ``(seq_len, embedding_dim)`` float + matrix. The resulting tensor has the input shape with ``seq_len`` and + ``embedding_dim`` appended as trailing dimensions. If the input has a + trailing size-1 axis, it is dropped so the output is + ``input.shape[:-1] + (seq_len, embedding_dim)``. This matches the + convention used by ``StringToStringListLayer``. + + :param inputs: String tensor of arbitrary shape. + :returns: Float32 tensor with shape + ``input.shape + (seq_len, embedding_dim)`` or, if the input has a + trailing size-1 axis, ``input.shape[:-1] + (seq_len, embedding_dim)``. + """ + input_dynamic_shape = tf.shape(inputs) + input_static_shape = inputs.shape.as_list() + drop_trailing_axis = ( + len(input_static_shape) >= 1 and input_static_shape[-1] == 1 + ) + + flat = tf.reshape(inputs, [-1]) + + # Unify the two separators so a single split yields all floats. + unified = tf.strings.regex_replace( + flat, re.escape(self.separator), self.sequence_separator + ) + + total_floats = self.seq_len * self.embedding_dim + split = tf.strings.split(unified, sep=self.sequence_separator) + dense = split.to_tensor( + default_value=self.pad_value, shape=[None, total_floats] + ) + # Replace any empty tokens (from leading/trailing/repeated separators + # or entirely empty inputs) with the pad value so tf.strings.to_number + # does not fail on the empty string. + dense = tf.where(tf.equal(dense, ""), self.pad_value, dense) + + floats = tf.strings.to_number(dense, out_type=tf.float32) + result = tf.reshape(floats, [-1, self.seq_len, self.embedding_dim]) + + if self.reverse: + # A row is considered padding iff all of its components are 0. + row_norms = tf.reduce_sum(tf.abs(result), axis=-1) + seq_lengths = tf.reduce_sum(tf.cast(row_norms > 0, tf.int32), axis=-1) + result = tf.reverse_sequence(result, seq_lengths, seq_axis=1, batch_axis=0) + + leading_shape = ( + input_dynamic_shape[:-1] if drop_trailing_axis else input_dynamic_shape + ) + output_shape = tf.concat( + [ + leading_shape, + tf.constant( + [self.seq_len, self.embedding_dim], dtype=input_dynamic_shape.dtype + ), + ], + axis=0, + ) + return tf.reshape(result, output_shape) + + def get_config(self) -> Dict[str, Any]: + """ + Gets the configuration of the StringSequenceToEmbedding layer. + Used for saving and loading from a model. + + :returns: Dictionary of the configuration of the layer. + """ + config = super().get_config() + config.update( + { + "seq_len": self.seq_len, + "embedding_dim": self.embedding_dim, + "separator": self.separator, + "sequence_separator": self.sequence_separator, + "pad_value": self.pad_value, + "reverse": self.reverse, + } + ) + return config diff --git a/tests/kamae/spark/transformers/test_string_sequence_to_embedding.py b/tests/kamae/spark/transformers/test_string_sequence_to_embedding.py new file mode 100644 index 0000000..1a3ed21 --- /dev/null +++ b/tests/kamae/spark/transformers/test_string_sequence_to_embedding.py @@ -0,0 +1,217 @@ +# Copyright [2024] Expedia, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest +import tensorflow as tf + +from kamae.spark.transformers import StringSequenceToEmbeddingTransformer + + +class TestStringSequenceToEmbedding: + @pytest.fixture(scope="class") + def example_dataframe(self, spark_session): + return spark_session.createDataFrame( + [ + ("1|2|3,4|5|6,0|0|0,0|0|0",), + ("7|8|9,1|1|1,0|0|0,0|0|0",), + ("1|2|3",), # short input, requires padding + ("1|2|3,4|5|6,7|8|9,1|1|1,9|9|9",), # long input, requires truncation + ], + ["embedding_str"], + ) + + def test_string_sequence_to_embedding_transform_defaults(self): + transformer = StringSequenceToEmbeddingTransformer() + assert transformer.getSeqLen() == 10 + assert transformer.getEmbeddingDim() == 32 + assert transformer.getSeparator() == "|" + assert transformer.getSequenceSeparator() == "," + assert transformer.getPadValue() == "0" + assert transformer.getReverse() is False + assert transformer.getLayerName() == transformer.uid + assert transformer.getOutputCol() == f"{transformer.uid}__output" + + def test_spark_transform_basic(self, example_dataframe): + transformer = StringSequenceToEmbeddingTransformer( + inputCol="embedding_str", + outputCol="embedding", + seqLen=4, + embeddingDim=3, + ) + actual = transformer.transform(example_dataframe) + rows = actual.select("embedding").collect() + expected = [ + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ], + [ + [7.0, 8.0, 9.0], + [1.0, 1.0, 1.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ], + [ + [1.0, 2.0, 3.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ], + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + [7.0, 8.0, 9.0], + [1.0, 1.0, 1.0], + ], + ] + np.testing.assert_allclose( + np.array([r["embedding"] for r in rows]), + np.array(expected), + atol=1e-6, + ) + + def test_spark_transform_reverse(self, example_dataframe): + transformer = StringSequenceToEmbeddingTransformer( + inputCol="embedding_str", + outputCol="embedding", + seqLen=4, + embeddingDim=3, + reverse=True, + ) + actual = transformer.transform(example_dataframe) + rows = actual.select("embedding").collect() + expected = [ + # Reverse only the non-pad prefix (first two vectors). + [ + [4.0, 5.0, 6.0], + [1.0, 2.0, 3.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ], + [ + [1.0, 1.0, 1.0], + [7.0, 8.0, 9.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ], + # Single non-pad vector remains unchanged when reversed. + [ + [1.0, 2.0, 3.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ], + # All four slots filled: full reverse. + [ + [1.0, 1.0, 1.0], + [7.0, 8.0, 9.0], + [4.0, 5.0, 6.0], + [1.0, 2.0, 3.0], + ], + ] + np.testing.assert_allclose( + np.array([r["embedding"] for r in rows]), + np.array(expected), + atol=1e-6, + ) + + @pytest.mark.parametrize( + "input_strings, seq_len, embedding_dim, separator, sequence_separator, pad_value, reverse", + [ + ( + [ + "1|2|3,4|5|6,0|0|0,0|0|0", + "7|8|9,1|1|1,0|0|0,0|0|0", + "1|2|3", + "1|2|3,4|5|6,7|8|9,1|1|1,9|9|9", + ], + 4, + 3, + "|", + ",", + "0", + False, + ), + ( + [ + "1|2|3,4|5|6,0|0|0,0|0|0", + "7|8|9,1|1|1,2|2|2,0|0|0", + ], + 4, + 3, + "|", + ",", + "0", + True, + ), + ( + [ + "1:2:3;4:5:6", + "9:9:9;0:0:0", + ], + 2, + 3, + ":", + ";", + "0", + False, + ), + ], + ) + def test_spark_tf_parity( + self, + spark_session, + input_strings, + seq_len, + embedding_dim, + separator, + sequence_separator, + pad_value, + reverse, + ): + transformer = StringSequenceToEmbeddingTransformer( + inputCol="embedding_str", + outputCol="embedding", + seqLen=seq_len, + embeddingDim=embedding_dim, + separator=separator, + sequenceSeparator=sequence_separator, + padValue=pad_value, + reverse=reverse, + ) + + spark_df = spark_session.createDataFrame( + [(s,) for s in input_strings], ["embedding_str"] + ) + spark_values = np.array( + [ + row["embedding"] + for row in transformer.transform(spark_df).select("embedding").collect() + ] + ) + + # Use input shape (batch, 1); the layer drops the trailing size-1 axis + # so the output is (batch, seq_len, embedding_dim), matching Spark. + tf_input = tf.constant([[s] for s in input_strings]) + tf_values = transformer.get_tf_layer()(tf_input).numpy() + + np.testing.assert_allclose( + spark_values, + tf_values, + atol=1e-6, + err_msg="Spark and TF outputs differ", + ) diff --git a/tests/kamae/tensorflow/layers/test_string_sequence_to_embedding.py b/tests/kamae/tensorflow/layers/test_string_sequence_to_embedding.py new file mode 100644 index 0000000..3b2880b --- /dev/null +++ b/tests/kamae/tensorflow/layers/test_string_sequence_to_embedding.py @@ -0,0 +1,201 @@ +# Copyright [2024] Expedia, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import tensorflow as tf + +from kamae.tensorflow.layers import StringSequenceToEmbeddingLayer + + +class TestStringSequenceToEmbedding: + def test_default_separators_drops_trailing_one_axis(self): + layer = StringSequenceToEmbeddingLayer( + name="default_separators", + seq_len=4, + embedding_dim=3, + ) + # Shape (1, 1) with a trailing size-1 axis: expect it to be squeezed. + inputs = tf.constant([["1|2|3,4|5|6,0|0|0,0|0|0"]]) + expected = tf.constant( + [ + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ] + ], + dtype=tf.float32, + ) + output = layer(inputs) + assert output.shape == (1, 4, 3) + tf.debugging.assert_near(expected, output) + + def test_drops_trailing_one_axis_on_rank_three_input(self): + layer = StringSequenceToEmbeddingLayer( + name="drop_trailing_rank_three", + seq_len=4, + embedding_dim=3, + ) + # Shape (None, 1, 1) -> expect (None, 1, 4, 3). + inputs = tf.constant([[["1|2|3,4|5|6,0|0|0,0|0|0"]]]) + assert inputs.shape == (1, 1, 1) + output = layer(inputs) + assert output.shape == (1, 1, 4, 3) + + def test_no_trailing_one_axis_keeps_input_shape(self): + layer = StringSequenceToEmbeddingLayer( + name="no_squeeze", + seq_len=2, + embedding_dim=2, + ) + # Last axis size > 1 -> do NOT drop; output is input.shape + (seq_len, d). + inputs = tf.constant([["1|2,3|4", "5|6,7|8"]]) + assert inputs.shape == (1, 2) + output = layer(inputs) + assert output.shape == (1, 2, 2, 2) + expected = tf.constant( + [ + [ + [[1.0, 2.0], [3.0, 4.0]], + [[5.0, 6.0], [7.0, 8.0]], + ] + ], + dtype=tf.float32, + ) + tf.debugging.assert_near(expected, output) + + def test_pads_short_sequences_and_truncates_long_ones(self): + layer = StringSequenceToEmbeddingLayer( + name="pad_and_truncate", + seq_len=3, + embedding_dim=2, + pad_value="0", + ) + inputs = tf.constant( + [ + # Short: only 2 vectors supplied, last vector should be pad. + ["1|2,3|4"], + # Long: 4 vectors supplied, last one should be truncated. + ["1|2,3|4,5|6,7|8"], + ] + ) + expected = tf.constant( + [ + [[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]], + [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], + ], + dtype=tf.float32, + ) + output = layer(inputs) + assert output.shape == (2, 3, 2) + tf.debugging.assert_near(expected, output) + + def test_reverse_reverses_only_non_pad_portion(self): + layer = StringSequenceToEmbeddingLayer( + name="reverse", + seq_len=4, + embedding_dim=2, + reverse=True, + ) + inputs = tf.constant([["1|1,2|2,3|3,0|0"]]) + # Non-pad portion "1|1, 2|2, 3|3" reversed -> "3|3, 2|2, 1|1". + expected = tf.constant( + [ + [ + [3.0, 3.0], + [2.0, 2.0], + [1.0, 1.0], + [0.0, 0.0], + ] + ], + dtype=tf.float32, + ) + output = layer(inputs) + tf.debugging.assert_near(expected, output) + + def test_empty_and_malformed_inputs_do_not_fail(self): + """Inputs with empty cells or leading/repeated separators should not + crash tf.strings.to_number; empty tokens should be treated as pad.""" + layer = StringSequenceToEmbeddingLayer( + name="empty_handling", + seq_len=3, + embedding_dim=2, + pad_value="0", + ) + inputs = tf.constant( + [ + [""], # fully empty cell + [",1|2,3|4"], # leading separator + ["1|2,,3|4"], # repeated separator producing empty token + ["1|2,3|4,"], # trailing separator + ] + ) + expected = tf.constant( + [ + [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], + [[0.0, 1.0], [2.0, 3.0], [4.0, 0.0]], + [[1.0, 2.0], [0.0, 3.0], [4.0, 0.0]], + [[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]], + ], + dtype=tf.float32, + ) + output = layer(inputs) + assert output.shape == (4, 3, 2) + tf.debugging.assert_near(expected, output) + + def test_custom_separators(self): + layer = StringSequenceToEmbeddingLayer( + name="custom_separators", + seq_len=2, + embedding_dim=3, + separator=":", + sequence_separator=";", + ) + inputs = tf.constant([["1:2:3;4:5:6"]]) + expected = tf.constant( + [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]], + dtype=tf.float32, + ) + output = layer(inputs) + tf.debugging.assert_near(expected, output) + + def test_get_config_round_trip(self): + layer = StringSequenceToEmbeddingLayer( + name="round_trip", + seq_len=5, + embedding_dim=4, + separator="|", + sequence_separator=",", + pad_value="0", + reverse=True, + ) + config = layer.get_config() + assert config["seq_len"] == 5 + assert config["embedding_dim"] == 4 + assert config["separator"] == "|" + assert config["sequence_separator"] == "," + assert config["pad_value"] == "0" + assert config["reverse"] is True + recovered = StringSequenceToEmbeddingLayer.from_config(config) + assert recovered.seq_len == 5 + assert recovered.embedding_dim == 4 + + def test_invalid_arguments(self): + with pytest.raises(ValueError): + StringSequenceToEmbeddingLayer(seq_len=0, embedding_dim=3) + with pytest.raises(ValueError): + StringSequenceToEmbeddingLayer(seq_len=3, embedding_dim=0) + with pytest.raises(ValueError): + StringSequenceToEmbeddingLayer(separator=",", sequence_separator=",") diff --git a/tests/kamae/tensorflow/test_layer_serialisation.py b/tests/kamae/tensorflow/test_layer_serialisation.py index 40ec92b..e29b2b0 100644 --- a/tests/kamae/tensorflow/test_layer_serialisation.py +++ b/tests/kamae/tensorflow/test_layer_serialisation.py @@ -99,6 +99,7 @@ StringListToStringLayer, StringMapLayer, StringReplaceLayer, + StringSequenceToEmbeddingLayer, StringToStringListLayer, SubStringDelimAtIndexLayer, SubtractLayer, @@ -525,6 +526,19 @@ }, False, ), + ( + StringSequenceToEmbeddingLayer, + [tf.constant("1|2|3,4|5|6,0|0|0,0|0|0", shape=(8, 4))], + { + "seq_len": 4, + "embedding_dim": 3, + "separator": "|", + "sequence_separator": ",", + "pad_value": "0", + "reverse": True, + }, + False, + ), ( StringToStringListLayer, [tf.constant("a", shape=(100, 5))],