Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 6 additions & 34 deletions Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -1201,8 +1201,8 @@ pipeline {
description: "Run the ck_tile FMHA tests (default: OFF)")
booleanParam(
name: "RUN_TILE_ENGINE_BASIC_TESTS",
defaultValue: false,
description: "Run the tile_engine_basic tests (default: OFF)")
defaultValue: true,
description: "Run the tile_engine_basic tests (default: ON)")
booleanParam(
name: "RUN_TILE_ENGINE_GEMM_TESTS",
defaultValue: false,
Expand Down Expand Up @@ -1650,7 +1650,10 @@ pipeline {
-D GEMM_PRESHUFFLE_DATATYPE="fp16;fp8;bf16;bf8" \
-D GEMM_PRESHUFFLE_LAYOUT="rcr" \
-D GEMM_PRESHUFFLE_CONFIG_FILE="default_ci_config.json" .. && \
ninja -j${nthreads()} benchmark_gemm_universal_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all """
ninja -j${nthreads()} benchmark_gemm_universal_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all && \
python3 ../tile_engine/ops/gemm/gemm_universal/gemm_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
python3 ../tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
python3 ../tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json """
}
steps{
buildHipClangJobAndReboot(setup_args:setup_args, build_type: 'Release', execute_cmd: execute_args)
Expand All @@ -1667,37 +1670,6 @@ pipeline {
}
parallel
{
stage("Run TILE_ENGINE_GEMM Tests on gfx90a")
{
when {
beforeAgent true
expression { params.RUN_TILE_ENGINE_GEMM_TESTS.toBoolean() }
}
agent{ label rocmnode("gfx90a") }
environment{
setup_args = "NO_CK_BUILD"
execute_args = """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \
-D CMAKE_CXX_COMPILER="${params.BUILD_COMPILER}" \
-D CMAKE_BUILD_TYPE=Release \
-D GPU_TARGETS="gfx90a" \
-D GEMM_UNIVERSAL_DATATYPE="fp8;fp16" \
-D GEMM_UNIVERSAL_LAYOUT="rcr;rrr;crr;ccr" \
-D GEMM_STREAMK_DATATYPE="fp8;fp16" \
-D GEMM_STREAMK_LAYOUT="rcr" \
-D GEMM_MULTI_D_DATATYPE="fp16" \
-D GEMM_MULTI_D_LAYOUT="rcrr;rrrr;crrr;ccrr" \
-D GEMM_PRESHUFFLE_DATATYPE="fp16;fp8;bf16;bf8" \
-D GEMM_PRESHUFFLE_LAYOUT="rcr" .. && \
ninja -j${nthreads()} benchmark_gemm_universal_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all benchmark_gemm_streamk_all && \
python3 ../tile_engine/ops/gemm/gemm_universal/gemm_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
python3 ../tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
python3 ../tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json """
}
steps{
buildHipClangJobAndReboot(setup_args:setup_args, build_type: 'Release', execute_cmd: execute_args)
cleanWs()
}
}
stage("Run TILE_ENGINE_GEMM Tests on gfx942")
{
when {
Expand Down
52 changes: 14 additions & 38 deletions tile_engine/ops/gemm/gemm_instance_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,40 +643,31 @@ def populate_launch(

using GemmPipeline = {pipeline_impl_map.get(pipeline)}<UniversalGemmProblem>;"""

# Runfunction body
instance_code += """

const auto Run = [&](const auto memory_operation_) {"""

# Scheduler initialization
if self.kernel_name_prefix in ["gemm_universal"]:
instance_code += f"""
constexpr auto scheduler = {scheduler_type_map.get(scheduler)};"""

# Memory operation
instance_code += """
[[maybe_unused]] constexpr auto memory_operation = memory_operation_.value;"""
constexpr auto scheduler = {scheduler_type_map.get(scheduler)};"""

# UniversalGemmProblem
if self.kernel_name_prefix in ["gemm_universal"]:
instance_code += """

using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
ADataType,
BDataType,
AccDataType,
TileShape,
ck_tile::TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
ALayout, BLayout, CLayout, TransposeC,
UseStructuredSparsity, UsePersistentKernel,
NumWaveGroups, Preshuffle>,
scheduler>;"""
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
ADataType,
BDataType,
AccDataType,
TileShape,
ck_tile::TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
ALayout, BLayout, CLayout, TransposeC,
UseStructuredSparsity, UsePersistentKernel,
NumWaveGroups, Preshuffle>,
scheduler>;"""

# GemmPipeline
if self.kernel_name_prefix in ["gemm_universal"]:
instance_code += f"""

using GemmPipeline = {pipeline_impl_map.get(pipeline)}<UniversalGemmProblem>;"""
using GemmPipeline = {pipeline_impl_map.get(pipeline)}<UniversalGemmProblem>;"""

# Epilogue
instance_code += self.populate_epilogue(epilogue)
Expand Down Expand Up @@ -748,23 +739,8 @@ def populate_launch(
ck_tile::make_kernel<kBlockPerCu>(GemmKernel{{}}, grids, blocks, 0, kargs));

return ave_time;
}};"""

# Run SplitK handler

instance_code += """

float ave_time = 0.f;
if(args.k_batch == 1) {
ave_time = Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
} else {
ave_time = Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
return ave_time;
}
};
}}
}};
"""
return instance_code

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@
"trait_config": {
"pipeline": {
"values": [
"compv4"
"compv3",
"compv4",
"mem"
]
},
"scheduler": {
Expand All @@ -60,7 +62,8 @@
},
"epilogue": {
"values": [
"cshuffle"
"cshuffle",
"default"
]
},
"pad_m": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
},
"epilogue": {
"values": [
"default",
"cshuffle"
]
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@
"trait_config": {
"pipeline": {
"values": [
"compv4"
"compv3",
"compv4",
"mem"
]
},
"scheduler": {
Expand All @@ -60,7 +62,8 @@
},
"epilogue": {
"values": [
"cshuffle"
"cshuffle",
"default"
]
},
"pad_m": {
Expand Down