Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
9b92f11
[Frontend/template] add SDPA modules
student-Jungmin Mar 2, 2026
fc247be
[Template] Add cat & sort template + Multi-output (WIP)
Jagggged Mar 1, 2026
f615178
[Fix] Prevent fallback to eager mode after reaching compilation limit…
student-Jungmin Mar 4, 2026
8ca5d02
[FIX] Add idx_map to the first matmul for logical consistency
student-Jungmin Mar 4, 2026
41288bc
[Template] Polish template kernel of cat operation
YWHyuk Mar 3, 2026
434bbb1
[WIP]
YWHyuk Mar 4, 2026
5295dfb
[Template] Delay def_dma_op codegen
YWHyuk Mar 4, 2026
61caebd
[Template/Cat] Fix apply offset setting
YWHyuk Mar 4, 2026
47684a7
[TOGSim] Add help print
YWHyuk Mar 5, 2026
a24f1f1
[Template/Cat] Limit maximum rank of tile
YWHyuk Mar 5, 2026
4e4300e
[Template/Cat] Refactor cat + Support explicit dram+stride in def_dma_op
YWHyuk Mar 5, 2026
3d9cb38
[Frontend/template] Connect SDPA template to NPU using Torch OpenReg
student-Jungmin Mar 5, 2026
591e8a9
[Templte/Cat] Apply copy operation when node has view
YWHyuk Mar 5, 2026
dab3495
[Refactor] Refactored TopK test code for the OpenReg device
student-Jungmin Mar 7, 2026
a15f5d2
[Template/Sort] Add template code for Bitonic sort
YWHyuk Mar 11, 2026
752cbb8
[Template] Use buffer type instead of hard-coded type
YWHyuk Mar 11, 2026
7af91de
[Frontend] Fix incorrect constant key usage and boolean scientific-no…
HamHyungkyu Mar 12, 2026
7bad17a
[Fix] Refactor MLIR precision handling to be dtype-driven
YWHyuk Mar 11, 2026
fadba78
[Fix] malloc size align + fix origin info
YWHyuk Mar 12, 2026
0189ab9
[TOGSim] Fix local/remote memory stat
YWHyuk Mar 12, 2026
f7f2696
Merge branch 'feat/deepseek' into feature/TopK
YWHyuk Mar 13, 2026
37474cd
Merge pull request #218 from student-Jungmin/feature/TopK
YWHyuk Mar 13, 2026
5268be2
[Frontend/template] add SPDA decode GQA template imlementation
HamHyungkyu Mar 12, 2026
59bd8f8
WIP
YWHyuk Mar 12, 2026
bfc2b22
[Frontend/template] SPDA implementation debug
HamHyungkyu Mar 13, 2026
ce93306
[Template/SPDA] Remove subtile size temporarily
YWHyuk Mar 13, 2026
f2717e1
[Template/SPDA] minor fix
YWHyuk Mar 13, 2026
be23638
[Cleanup] Unflag debug option
YWHyuk Mar 16, 2026
e925ae4
[CI] Add deepseek test case
YWHyuk Mar 16, 2026
db85991
[Template/SPDA] Cleanup test case + Add an activate option
YWHyuk Mar 17, 2026
dd71c70
[Frontend] Handle RecompileSignal in MLIRKernel code generation
HamHyungkyu Mar 17, 2026
c5f085e
[Frontend] Enhance vector size handling for low-precision paths in ML…
HamHyungkyu Mar 17, 2026
fdd5b54
[Refactor] move to TOGSimulator-based scheduler API
YWHyuk Mar 18, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions .github/workflows/pytorchsim_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,27 @@ jobs:
-e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \
${{ inputs.image_name }} python3 PyTorchSim/tests/test_conv2d.py

test_cat:
name: Run test_cat.py
runs-on: self-hosted
steps:
- name: Log in to GitHub Container Registry
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}

- name: Run test_cat.py
run: |
echo "Running test_cat.py"
docker run --rm \
-v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \
-e TORCHSIM_DUMP_PATH=/dump \
-e vpu_num_lanes="${{ inputs.vector_lane }}" \
-e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \
${{ inputs.image_name }} python3 PyTorchSim/tests/test_cat.py

test_matmul:
name: Run test_matmul.py
runs-on: self-hosted
Expand Down Expand Up @@ -705,6 +726,27 @@ jobs:
-e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \
${{ inputs.image_name }} python3 PyTorchSim/tests/Yolov5/test_yolov5.py

test_deepseek:
name: Run test_deepseek
runs-on: self-hosted
steps:
- name: Log in to GitHub Container Registry
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}

- name: Run test_deepseek_v3_base.py
run: |
echo "Running test_deepseek_v3_base.py"
docker run --rm \
-v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \
-e TORCHSIM_DUMP_PATH=/dump \
-e vpu_num_lanes="${{ inputs.vector_lane }}" \
-e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \
${{ inputs.image_name }} python3 PyTorchSim/tests/DeepSeek/test_deepseek_v3_base.py

test_accuracy:
name: Run test_accuracy
runs-on: self-hosted
Expand Down
4 changes: 2 additions & 2 deletions AsmParser/tog_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class tog_generator:
StonneTraceCompute= 6
StonneTraceLoad = 7
StonneTraceStore = 8
def __init__(self, origins="Unknown") -> None:
def __init__(self, origins={"Unknown"}) -> None:
self.module_name = "tile_operation_graph"
self.module = None
self.raw_graph = {}
Expand Down Expand Up @@ -226,7 +226,7 @@ def generate_tile_graph(self, name="tile_graph", cycle_list=list, x_offset=int,
offset = w_offset if is_preload else x_offset
iter_node.torchsim_overlapping_cycle = max(iter_node.torchsim_cycle - offset, 0)

origin_info = "_".join(map(str, self.origins))
origin_info = self.origins if isinstance(self.origins, str) else "_".join(map(str, self.origins))
onnx_node_list = [node.to_onnx() for node in node_list] # Exclude root node
dump_onnx_graph(name, onnx_node_list, vector_lane, origin_info, stonneGraph=stonneGraph)

Expand Down
34 changes: 1 addition & 33 deletions PyTorchSimDevice/csrc/aten/OpenRegExtra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <ATen/native/CPUFallback.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/transformers/attention.h>

#include <torch/csrc/autograd/autograd_not_implemented_fallback.h>
#include <torch/library.h>
Expand Down Expand Up @@ -40,36 +41,6 @@ void wrapper_quantize_tensor_per_tensor_affine_stub(
rtensor, qtensor, scale, zero_point);
}

std::tuple<
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
c10::SymInt,
c10::SymInt,
at::Tensor,
at::Tensor,
at::Tensor>
wrapper__scaled_dot_product_fused_attention_overrideable(
const at::Tensor& query,
const at::Tensor& key,
const at::Tensor& value,
const std::optional<at::Tensor>& attn_bias,
double dropout_p,
bool is_causal,
bool return_debug_mask,
std::optional<double> scale) {
return at::native::openreg::_scaled_dot_product_fused_attention_overrideable(
query,
key,
value,
attn_bias,
dropout_p,
is_causal,
return_debug_mask,
scale);
}

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
wrapper_scaled_dot_product_fused_attention_overrideable_backward(
const at::Tensor& grad_out,
Expand Down Expand Up @@ -172,9 +143,6 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
m.impl("abs.out", &wrapper_abs_out);
m.impl("quantize_per_tensor", &wrapper_quantize_per_tensor);
m.impl("_fused_sdp_choice", &wrapper__fused_sdp_choice);
m.impl(
"_scaled_dot_product_fused_attention_overrideable",
&wrapper__scaled_dot_product_fused_attention_overrideable);
m.impl(
"_scaled_dot_product_fused_attention_overrideable_backward",
&wrapper_scaled_dot_product_fused_attention_overrideable_backward);
Expand Down
83 changes: 33 additions & 50 deletions PyTorchSimDevice/csrc/aten/native/Extra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,39 @@ int64_t _fused_sdp_choice(
bool is_causal,
std::optional<double> scale,
bool enable_gqa) {
auto backend = sdp::SDPBackend::math;
return static_cast<int64_t>(backend);

sdp::sdp_params params{query, key, value, attn_mask, dropout_p, is_causal, enable_gqa};

// Reject inputs that are fundamentally unsupported (e.g. wrong rank)
if (!sdp::check_tensor_shapes(params, /*debug=*/false)) {
return static_cast<int64_t>(sdp::SDPBackend::error);
}

// q: (B, Hq, L, E) k/v: (B, H, S, E)
const int64_t Hq = query.size(-3);
const int64_t H = key.size(-3);
const int64_t L = query.size(-2); // query sequence length
const int64_t S = key.size(-2); // key/value sequence length

// Conditions required by the MLIR FlashSDPA kernel:
// Prefill only : L == S (decode has L == 1, not supported)
// Non-GQA : Hq == H (equal query and KV heads)
// No dropout : template has no dropout implementation
// Dense tensors : no nested tensor support
const bool can_use_mlir_flash =
(L == S) &&
(Hq == H) && !enable_gqa &&
sdp::check_for_dropout(params, /*debug=*/false) &&
sdp::check_nested_tensor(params, /*debug=*/false);

const bool ctx_flash = at::globalContext().userEnabledFlashSDP();
const bool ctx_math = at::globalContext().userEnabledMathSDP();

if (ctx_flash && can_use_mlir_flash) {
return static_cast<int64_t>(sdp::SDPBackend::overrideable);
}

return static_cast<int64_t>(sdp::SDPBackend::math);
}

void quantize_tensor_per_tensor_affine_stub(
Expand All @@ -29,54 +60,6 @@ void quantize_tensor_per_tensor_affine_stub(
double scale,
int64_t zero_point) {}

std::tuple<
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
c10::SymInt,
c10::SymInt,
at::Tensor,
at::Tensor,
at::Tensor>
_scaled_dot_product_fused_attention_overrideable(
const at::Tensor& query,
const at::Tensor& key,
const at::Tensor& value,
const std::optional<at::Tensor>& attn_bias,
double dropout_p,
bool is_causal,
bool return_debug_mask,
std::optional<double> scale) {
const int64_t batch_size = query.size(0);
const int64_t num_heads = query.size(1);
const int64_t head_dim_v = value.size(3);
const int64_t max_seqlen_q = query.size(2);
const int64_t max_seqlen_kv = key.size(2);

auto opts = query.options();
auto output =
at::empty({batch_size, num_heads, max_seqlen_q, head_dim_v}, opts);
auto logsumexp =
at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
auto debug_attn_mask = at::empty(
{batch_size, num_heads, max_seqlen_q, max_seqlen_kv},
opts.dtype(at::kFloat));
auto philox_seed = at::empty({}, at::dtype(at::kLong));
auto philox_offset = at::empty({}, at::dtype(at::kLong));

return std::make_tuple(
output,
logsumexp,
at::Tensor(),
at::Tensor(),
max_seqlen_q,
max_seqlen_kv,
philox_seed,
philox_offset,
debug_attn_mask);
}

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
_scaled_dot_product_fused_attention_overrideable_backward(
const at::Tensor& grad_out,
Expand Down
12 changes: 11 additions & 1 deletion PyTorchSimDevice/torch_openreg/openreg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class device:

def __init__(self, device):
self.idx = torch.accelerator._get_device_index(device, optional=True)
self.prev_idx = -1
self.prev_idx = -1

def __enter__(self):
self.prev_idx = torch_openreg._C._exchangeDevice(self.idx)
Expand Down Expand Up @@ -64,10 +64,20 @@ def _lazy_init():
global _initialized, _tog_simulator
if is_initialized():
return

# Replace the global C++ binding with our custom dispatcher patch
# from PyTorchSimFrontend.mlir.mlir_sdpa_template import patched_scaled_dot_product_attention
# torch._C._nn.scaled_dot_product_attention = patched_scaled_dot_product_attention

torch_openreg._C._init()
register_interface_for_device(custom_device(), ExtensionDeviceInterface)
_initialized = True

# Set default SDPA backend to math-only for this device.
torch._C._set_sdp_use_flash(False)
torch._C._set_sdp_use_overrideable(False)
torch._C._set_sdp_use_math(True)

# Create default streams for all devices
num_devices = device_count()
for device_idx in range(num_devices):
Expand Down
18 changes: 15 additions & 3 deletions PyTorchSimFrontend/extension_codecache.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,17 @@ def mlir_compile_command(filename, vectorlane_size, vlen=256):
f"""
{extension_config.CONFIG_TORCHSIM_LLVM_PATH}/llc \
-relocation-model=pic -march=riscv64 -O3 --stack-size-section \
-mattr=+m,+f,+d,+a,+c,+v,+xsfvcp,zvl{vlen}b \
-mattr=+m,+f,+d,+a,+c,+v,+zvfh,+xsfvcp,zvl{vlen}b \
-filetype=obj \
{'--print-after-all' if extension_config.CONFIG_TORCHSIM_DUMP_LLVM_IR else ''} \
-O2 {filename}.ll -o {filename}.o
""",
).strip(),
re.sub(r"[ \n]+", " ",
f"""
{extension_config.CONFIG_TORCHSIM_LLVM_PATH}/llc \
-relocation-model=pic -march=riscv64 -O3 --stack-size-section \
-mattr=+m,+f,+d,+a,+c,+v,+zvfh,+xsfvcp,zvl{vlen}b \
-O2 {filename}.ll -o {filename}.s
""",
).strip()]
Expand Down Expand Up @@ -109,9 +118,10 @@ def mlir_gem5_compile_command(filename, sample_filename, tog_file, vectorlane_si
f"""
{extension_config.CONFIG_TORCHSIM_LLVM_PATH}/llc \
-relocation-model=pic -march=riscv64 -O3 --stack-size-section \
-mattr=+m,+f,+d,+a,+c,+v,+xsfvcp,zvl{vlen}b \
-mattr=+m,+f,+d,+a,+c,+v,+zvfh,+xsfvcp,zvl{vlen}b \
-filetype=obj \
{'--print-after-all' if extension_config.CONFIG_TORCHSIM_DUMP_LLVM_IR else ''} \
-O2 {sample_filename}.ll -o {sample_filename}.s
-O2 {sample_filename}.ll -o {sample_filename}.o
""",
).strip()]

Expand Down Expand Up @@ -166,11 +176,13 @@ def load(cls, source_code,
opt_cmd = shlex.split(cmds[0])
translate_cmd = shlex.split(cmds[1])
llc_cmd = shlex.split(cmds[2])
llc_asm_cmd = shlex.split(cmds[3])
with lock:
try:
subprocess.check_call(opt_cmd)
subprocess.check_call(translate_cmd)
subprocess.check_call(llc_cmd)
subprocess.check_call(llc_asm_cmd)
except subprocess.CalledProcessError as e:
logger.error(f"Command failed with exit code {e.returncode}")
logger.error(f"Error output: {e.output.decode() if isinstance(e.output, bytes) else e.output}")
Expand Down
2 changes: 0 additions & 2 deletions PyTorchSimFrontend/extension_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ def __getattr__(name):
"spad_size" : config_yaml["vpu_spad_size_kb_per_lane"] << 10 # Note: spad size per lane
}

if name == "CONFIG_PRECISION":
return 4 # 32bit
if name == "CONFIG_NUM_CORES":
return config_yaml["num_cores"]
if name == "vpu_vector_length_bits":
Expand Down
2 changes: 1 addition & 1 deletion PyTorchSimFrontend/mlir/mlir_autotune.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def cached_run_fn(*args, **kwargs):
self.source_code, vectorlane_size=self.extra_args["vector_lane"],
loop_size=None, spad_info=self.extra_args["spad_info"],
vlen=self.extra_args["vlen"], arg_attributes=self.extra_args["arg_attributes"],
origins="Unknown", silent_mode=True,
origins=self.extra_args["origins"], silent_mode=True,
autotune=self.extra_args['autotune'])

args = [
Expand Down
Loading
Loading