Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions chispa/bcolors.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,6 @@ class bcolors:
LightCyan = '\033[36m'
White = '\033[97m'

# Style
Bold = '\033[1m'
Underline = '\033[4m'


def blue(s: str) -> str:
return bcolors.LightBlue + str(s) + bcolors.LightRed
Expand Down
72 changes: 58 additions & 14 deletions chispa/rows_comparer.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,35 @@
import chispa.six as six
from chispa.prettytable import PrettyTable
from chispa.bcolors import *
from chispa.terminal_str_formatter import format_string, format_mismatched_cell
import chispa
from pyspark.sql.types import Row
from typing import List


def assert_basic_rows_equality(rows1, rows2, underline_cells=False):
if underline_cells:
row_column_names = rows1[0].__fields__
num_columns = len(row_column_names)
# {
# "mismatched_rows": ["red", "bold"],
# "matched_rows": "blue",
# "mismatched_cells": ["white", "underline"],
# "matched_cells": ["blue", "bold"]
# }


def assert_basic_rows_equality(rows1, rows2, underline_cells=False, formats={
"mismatched_rows": ["red", "bold"],
"matched_rows": ["blue"],
"mismatched_cells": ["white", "underline"],
"matched_cells": ["blue", "bold"]
}):
if rows1 != rows2:
t = PrettyTable(["df1", "df2"])
zipped = list(six.moves.zip_longest(rows1, rows2))
for r1, r2 in zipped:
if r1 == r2:
t.add_row([blue(r1), blue(r2)])
t.add_row([format_string(r1, formats["matched_rows"]), format_string(r2, formats["matched_rows"])])
else:
if underline_cells:
t.add_row(__underline_cells_in_row(
r1=r1, r2=r2, row_column_names=row_column_names, num_columns=num_columns))
else:
t.add_row([r1, r2])
r = detailed_row_formatter(r1=r1, r2=r2, formats=formats)
t.add_row(r)
raise chispa.DataFramesNotEqualError("\n" + t.get_string())


Expand Down Expand Up @@ -71,10 +79,8 @@ def __underline_cells_in_row(r1=Row, r2=Row, row_column_names=List[str], num_col
append_str = ", "

if r1[column] != r2[column]:
r1_string += underline_text(
f"{column}='{r1[column]}'") + f"{append_str}"
r2_string += underline_text(
f"{column}='{r2[column]}'") + f"{append_str}"
r1_string += underline_text(f"{column}='{r1[column]}'") + f"{append_str}"
r2_string += underline_text(f"{column}='{r2[column]}'") + f"{append_str}"
else:
r1_string += f"{column}='{r1[column]}'{append_str}"
r2_string += f"{column}='{r2[column]}'{append_str}"
Expand All @@ -83,3 +89,41 @@ def __underline_cells_in_row(r1=Row, r2=Row, row_column_names=List[str], num_col
r2_string += ")"

return [bcolors.LightRed + r1_string, r2_string]


def detailed_row_formatter(
r1,
r2,
formats) -> List[str]:
# row1_column_names = r1[0].__fields__
# row2_column_names = r1[0].__fields__
# num_columns1 = len(row1_column_names)
# num_columns2 = len(row2_column_names)
r1_string = ""
r2_string = ""
has_a_mismatched_cell = False
zipped = list(six.moves.zip_longest(row1_column_names, row2_column_names))
for column in zipped:
# for index, column in enumerate(row_column_names):
# if ((index+1) == num_columns):
# append_str = ""
# else:
# append_str = ", "
append_str = "&"

if r1[column] != r2[column]:
has_a_mismatched_cell = True
r1_string += format_string(f"{column}='{r1[column]}'", formats["mismatched_cells"]) + f"{append_str}"
r2_string += format_string(f"{column}='{r2[column]}'", formats["mismatched_cells"]) + f"{append_str}"
else:
r1_string += format_string(f"{column}='{r1[column]}'{append_str}", formats["matched_cells"])
r2_string += format_string(f"{column}='{r2[column]}'{append_str}", formats["matched_cells"])

if has_a_mismatched_cell:
r1_string = format_string("Row(", formats["mismatched_rows"]) + r1_string + format_string(")", formats["mismatched_rows"])
r2_string = format_string("Row(", formats["mismatched_rows"]) + r2_string + format_string(")", formats["mismatched_rows"])
else:
r1_string = format_string("Row(", formats["matched_rows"]) + r1_string + format_string(")", formats["matched_rows"])
r2_string = format_string("Row(", formats["matched_rows"]) + r2_string + format_string(")", formats["matched_rows"])

return [r1_string, r2_string]
20 changes: 20 additions & 0 deletions chispa/terminal_str_formatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
def format_string(input, formats):
formatting = {
"nc": '\033[0m', # No Color, reset all
"bold": '\033[1m',
"underline": '\033[4m',
"blink": '\033[5m',
"inverted": '\033[7m',
"hidden": '\033[8m',
"blue": '\033[34m',
"white": '\033[97m',
"red": '\033[31m',
}
formatted = input
for format in formats:
s = formatting[format]
formatted = s + str(formatted) + s
return formatting["nc"] + str(formatted) + formatting["nc"]

def format_mismatched_cell(input_text: str, mismatched_cells) -> str:
return format_string(input_text, mismatched_cells)
5 changes: 3 additions & 2 deletions tests/test_readme_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,9 @@ def test_remove_non_word_characters_long_error():
(None, None)
]
expected_df = spark.createDataFrame(expected_data, ["name", "clean_name"])
with pytest.raises(DataFramesNotEqualError) as e_info:
assert_df_equality(actual_df, expected_df)
assert_df_equality(actual_df, expected_df)
# with pytest.raises(DataFramesNotEqualError) as e_info:
# assert_df_equality(actual_df, expected_df)


def ignore_row_order():
Expand Down
7 changes: 7 additions & 0 deletions tests/test_terminal_str_formatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import pytest

from chispa.terminal_str_formatter import format_string


def test_it_can_make_a_blue_string():
print(format_string("hi", ["bold", "blink"]))