-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathrz_linear_kernel.cu
More file actions
669 lines (577 loc) · 24.5 KB
/
rz_linear_kernel.cu
File metadata and controls
669 lines (577 loc) · 24.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/core/TensorAccessor.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <thrust/execution_policy.h>
#include <thrust/unique.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/device_vector.h>
#include <thrust/sort.h>
#include <thrust/copy.h>
#include <thrust/host_vector.h>
#include <vector>
#include <stdio.h>
#define MAX_GRID_SIZE 8192 // upper bound on gridDim.x chosen by the host launchers in this file
#define MAX_BLOCK_SIZE 128 // upper bound on threads per block chosen by the non-tiled launchers and rz_get_idx
#define SMLS 16 // BLOCK in tiled multiplication: side of one hashed weight tile; the tiled kernel runs SMLS x SMLS threads per block
#define SMLSMASK 15L // SMLS - 1; `v & SMLSMASK` gives the position of index v inside its SMLS-wide tile
#define TTILE 4 // per thread responsibility TTILE x TTILE
#define SMLST 64 // SMLS * TTILE: side of the output tile computed by one block of the tiled kernel
#define SHIFT 1 // extra padding column on the tiled kernel's shared-memory arrays (presumably to avoid bank conflicts)
#define BITMASK 1048575L // 2^20 -1 for cases when range is < 10^6; only referenced by commented-out code in this file
#define BASIC 0 // NOTE(review): BASIC / TILED are not referenced by the code visible in this file
#define TILED 1
// Linear hash of the pair (a, b):
//   (a * R[3] + b * R[2] + R[1]) mod R[0],   with R = random_numbers.
// NOTE(review): `%` keeps the sign of the dividend, so a negative sum (negative inputs, or
// int64 overflow for large random numbers) yields a negative result — confirm the ranges.
// Modulo by a large number is comparatively expensive; a power-of-two mask (& BITMASK) was
// considered as an alternative but is not used.
__device__ int64_t hash_func(int64_t a, int64_t b, const torch::PackedTensorAccessor32<int64_t, 1, torch::RestrictPtrTraits> random_numbers) {
  const int64_t modulus = random_numbers[0];
  const int64_t mixed = a * random_numbers[3] + b * random_numbers[2] + random_numbers[1];
  return mixed % modulus;
}
// Linear hash of the triple (a, b, c):
//   (a * R[3] + b * R[2] + c * R[1] + R[4]) mod R[0],   with R = random_numbers.
// Reads random_numbers[0..4], so it needs at least 5 entries.
__device__ int64_t hash_func3(int64_t a, int64_t b, int64_t c, const torch::PackedTensorAccessor32<int64_t, 1, torch::RestrictPtrTraits> random_numbers) {
  const int64_t modulus = random_numbers[0];
  const int64_t mixed = a * random_numbers[3] + b * random_numbers[2] + c * random_numbers[1] + random_numbers[4];
  return mixed % modulus;
}
// Linear hash of the quadruple (a, b, c, d):
//   (a * R[3] + b * R[2] + c * R[1] + d * R[4] + R[5]) mod R[0],   with R = random_numbers.
// Reads random_numbers[0..5], so it needs at least 6 entries. A `& BITMASK` variant
// (instead of the modulo) was considered but is not used.
__device__ int64_t hash_func4(int64_t a, int64_t b, int64_t c, int64_t d, const torch::PackedTensorAccessor32<int64_t, 1, torch::RestrictPtrTraits> random_numbers) {
  const int64_t modulus = random_numbers[0];
  const int64_t mixed = a * random_numbers[3] + b * random_numbers[2] + c * random_numbers[1]
                      + d * random_numbers[4] + random_numbers[5];
  return mixed % modulus;
}
// Slot in hashed_weights of virtual weight (i, j) under the chunked (non-tiled) scheme.
// Row indices are grouped into chunks of `chunk_size`. Within column j, every row of one
// chunk shares the hashed base hash_func(chunk, j) and occupies consecutive slots after it,
// wrapping modulo `range`.
inline __device__ int64_t location(int64_t i, int64_t j, int chunk_size, const torch::PackedTensorAccessor32<int64_t, 1, torch::RestrictPtrTraits> random_numbers, int64_t range) {
  const int64_t chunk = i / chunk_size;
  const int64_t within = i % chunk_size;
  const int64_t base = hash_func(chunk, j, random_numbers);
  return (base + within) % range;
}
// Slot in hashed_weights of virtual weight (i, j) under the tiled scheme.
// The virtual matrix is cut into SMLS x SMLS tiles. Each tile is a contiguous run of
// SMLS*SMLS slots (row-major inside the tile) starting at a hashed base in
// [0, range - SMLS*SMLS], so the whole run stays inside [0, range).
// Requires range >= SMLS * SMLS; otherwise the modulus below is <= 0.
inline __device__ int64_t location_tiled(int64_t i, int64_t j, const torch::PackedTensorAccessor32<int64_t, 1, torch::RestrictPtrTraits> random_numbers, int64_t range) {
  const int64_t tile_row = i / SMLS;
  const int64_t tile_col = j / SMLS;
  const int64_t tile_base = hash_func(tile_row, tile_col, random_numbers) % (range - SMLS * SMLS + 1);
  const int64_t in_tile = (i & SMLSMASK) * SMLS + (j & SMLSMASK);
  return tile_base + in_tile;
}
// Slot in hashed_weights of element (i, j) inside tile (tile_i, tile_j).
// Used only by rz_linear_forward_cuda_kernelXXX.
// The tile is flattened as j * SMLS + i. The flat index is split into chunks of `chunk_size`,
// each chunk gets its own hashed base from hash_func3(tile_i, tile_j, chunk), and elements
// of a chunk follow that base consecutively, wrapping modulo `range`.
inline __device__ int64_t location_tile(int64_t tile_i, int64_t tile_j, int i, int j, int chunk_size,
    const torch::PackedTensorAccessor32<int64_t, 1, torch::RestrictPtrTraits> random_numbers, int64_t range) {
  const int flat = i + SMLS * j;
  const int chunk = flat / chunk_size;
  const int within = flat % chunk_size;
  const int64_t base = hash_func3(tile_i, tile_j, chunk, random_numbers);
  return (base + within) % range;
}
// Non-tiled forward pass:
//   output[b][n] = sum_k input[b][k] * hashed_weights[loc(k, n)]
// where loc(k, n) = (hash_func(k / chunk_size, n) % hashed_weight_size + k % chunk_size)
//                   % hashed_weight_size,
// i.e. the same slot as location(k, n, chunk_size, ...) for non-negative hash values.
// Launch layout (see rz_linear_forward_cuda): blocks stride over batch rows via
// blockIdx.x / gridDim.x; threads stride over output columns via threadIdx.y / blockDim.y.
// Every (row, column) pair with row < batch and column < output_dim is written exactly once.
// Requires chunk_size > 0 and hashed_weight_size > 0.
template<typename scalar_t>
__global__ void rz_linear_forward_cuda_kernel(
    torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits> hashed_weights,
    torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> input,
    torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> output,
    torch::PackedTensorAccessor32<int64_t, 1, torch::RestrictPtrTraits> random_numbers,
    int batch,
    int input_dim,
    int output_dim,
    int chunk_size,
    int hashed_weight_size
)
{
  const int num_chunks = (input_dim + chunk_size - 1) / chunk_size;
  for (int out_x = blockIdx.x; out_x < batch; out_x += gridDim.x) {
    // The column index must restart for every row this block handles. Previously it was
    // initialised once outside this loop, so rows after the block's first were never written
    // whenever batch > gridDim.x.
    for (int out_y = threadIdx.y; out_y < output_dim; out_y += blockDim.y) {
      scalar_t val = 0;
      for (int c = 0; c < num_chunks; c++) {
        // All rows of chunk c in column out_y read consecutive slots starting here.
        int64_t idx = hash_func(c, out_y, random_numbers) % hashed_weight_size;
        for (int ic = 0; ic < chunk_size; ic++) {
          const int kidx = c * chunk_size + ic;
          if (kidx >= input_dim) {
            break;  // only the last chunk can be partial
          }
          val += input[out_x][kidx] * hashed_weights[idx];
          idx = (idx + 1) % hashed_weight_size;
        }
      }
      output[out_x][out_y] = val;
    }
  }
}
// Gradient w.r.t. the input:
//   input_grad[b][i] += sum_k hashed_weights[loc(i, k)] * out_grad[b][k]
// where loc is location_tiled(i, k) when `tiled`, else location(i, k, chunk_size).
// Launch layout (see rz_linear_backward_cuda): blocks stride over batch rows via blockIdx.x;
// threads stride over input columns via threadIdx.y. Each (b, i) element is owned by a
// single thread, so no atomics are needed.
// `input` is unused but kept for interface compatibility.
template<typename scalar_t>
__global__ void rz_linear_backward_cuda_kernel_input(
    torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits> hashed_weights,
    torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> input,
    torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> out_grad,
    torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> input_grad,
    torch::PackedTensorAccessor32<int64_t, 1, torch::RestrictPtrTraits> random_numbers,
    int batch,
    int input_dim,
    int output_dim,
    int chunk_size,
    int64_t hashed_weight_size,
    bool tiled
)
{
  for (int in_x = blockIdx.x; in_x < batch; in_x += gridDim.x) {
    for (int in_y = threadIdx.y; in_y < input_dim; in_y += blockDim.y) {
      // Accumulate in a register instead of read-modify-writing global memory on every k.
      // The terms are summed in the same order as before.
      scalar_t acc = 0;
      for (int k = 0; k < output_dim; k++) {
        const int64_t idx = tiled
            ? location_tiled(in_y, k, random_numbers, hashed_weight_size)
            : location(in_y, k, chunk_size, random_numbers, hashed_weight_size);
        acc += hashed_weights[idx] * out_grad[in_x][k];
      }
      input_grad[in_x][in_y] += acc;
    }
  }
}
// Gradient w.r.t. the hashed weights.
// For each virtual weight (row, col) it computes
//   dot = sum_b input[b][row] * out_grad[b][col]
// and adds dot into weight_grad at that weight's hashed slot: location_tiled(row, col) when
// `tiled`, else location(row, col, chunk_size).
// Launch layout (see rz_linear_backward_cuda): blocks stride over rows (input_dim) via
// blockIdx.x; threads stride over columns (output_dim) via threadIdx.y.
// Several virtual weights can hash to the same slot, hence atomicAdd. For double this
// requires SM60+.
template<typename scalar_t>
__global__ void rz_linear_backward_cuda_kernel_weight(
    torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits> hashed_weights,
    torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> input,
    torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> out_grad,
    torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits> weight_grad,
    torch::PackedTensorAccessor32<int64_t, 1, torch::RestrictPtrTraits> random_numbers,
    int batch,
    int input_dim,
    int output_dim,
    int chunk_size,
    int64_t hashed_weight_size,
    bool tiled
)
{
  for (int row = blockIdx.x; row < input_dim; row += gridDim.x) {
    for (int col = threadIdx.y; col < output_dim; col += blockDim.y) {
      scalar_t dot = 0;
      for (int b = 0; b < batch; ++b) {
        dot += input[b][row] * out_grad[b][col];
      }
      int64_t slot;
      if (tiled) {
        slot = location_tiled(row, col, random_numbers, hashed_weight_size);
      } else {
        slot = location(row, col, chunk_size, random_numbers, hashed_weight_size);
      }
      atomicAdd(&weight_grad[slot], dot);
    }
  }
}
// Host launcher for the non-tiled forward kernel.
// Returns a (input.size(0), output_dim) tensor with input's options.
// Grid: min(batch, MAX_GRID_SIZE) blocks. Block: (1, min(output_dim, MAX_BLOCK_SIZE), 1).
// The kernel is launched on the current CUDA stream of input's device.
// Throws (TORCH_CHECK) if chunk_size is not positive or the kernel launch fails.
torch::Tensor rz_linear_forward_cuda (
    const torch::Tensor& hashed_weights, // 1 x n
    const torch::Tensor& input, // b x d1
    const torch::Tensor& random_numbers,
    int input_dim,
    int output_dim,
    int chunk_size
)
{
  // The kernel divides by chunk_size.
  TORCH_CHECK(chunk_size > 0, "rz_linear_forward_cuda: chunk_size must be positive, got ", chunk_size);
  auto output = at::empty({input.size(0), output_dim}, input.options());
  const int x_max = input.size(0);
  const int y_max = output_dim;
  if (x_max == 0 || y_max == 0) {
    // Nothing to compute; a zero-sized grid or block would be an invalid launch.
    return output;
  }
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(input.device().index());
  const dim3 block(1, y_max < MAX_BLOCK_SIZE ? y_max : MAX_BLOCK_SIZE, 1);
  const dim3 grid(x_max < MAX_GRID_SIZE ? x_max : MAX_GRID_SIZE, 1, 1);
  AT_DISPATCH_FLOATING_TYPES(hashed_weights.scalar_type(), "rz_linear_forward_cuda", ([&] {
    rz_linear_forward_cuda_kernel<scalar_t><<<grid, block, 0, stream>>>(
        hashed_weights.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>(),
        input.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
        output.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
        random_numbers.packed_accessor32<int64_t, 1, torch::RestrictPtrTraits>(),
        x_max,
        input_dim,
        output_dim,
        chunk_size,
        hashed_weights.size(0)
    );
  }));
  // Kernel launches do not report errors directly; surface launch failures here.
  const cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess, "rz_linear_forward_cuda kernel launch failed: ", cudaGetErrorString(err));
  return output;
}
// Tiled forward pass: output (M x N) = input (M x K) times a virtual K x N weight matrix.
// Each SMLS x SMLS tile of the virtual matrix is read from a contiguous run of SMLS*SMLS
// entries of hashed_weights, starting at the slot location_tiled gives for that tile.
// M, K, N are the batch size, input_dim and output_dim passed by rz_linear_forward_cuda_tiled.
// Launch layout assumed: blockDim = (SMLS, SMLS) and a 1-D grid. Each block computes one
// SMLST x SMLST output tile per iteration of the block_idx loop. Each thread owns
// TTILE x TTILE outputs spaced SMLS apart.
// Static shared memory: two float [SMLST][SMLST + SHIFT] arrays.
// NOTE(review): shared tiles and accumulators are `float` regardless of scalar_t, so double
// inputs are computed in single precision. A scalar_t version would need twice the shared
// memory for double; check the 48 KB static shared-memory limit before changing.
// Assumes hashed_weight_size >= SMLS * SMLS (location_tiled reduces modulo
// range - SMLS*SMLS + 1). chunk_size is accepted but not used by this kernel.
template<typename scalar_t>
__global__ void rz_linear_forward_cuda_kernel_tiled(
torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits> hashed_weights,
torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> input,
torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> output,
torch::PackedTensorAccessor32<int64_t, 1, torch::RestrictPtrTraits> random_numbers,
int M,
int K,
int N,
int chunk_size,
int hashed_weight_size
)
{
// Output tiles (SMLST x SMLST) are numbered row-major; num_block_width is the tile count along N.
int num_block_width = (N + SMLST - 1) / SMLST;
int total_number_of_blocks = (int)((M + SMLST - 1) / SMLST) * (int)((N + SMLST - 1) / SMLST) ;
int block_idx = blockIdx.x;
int block_x, block_y, tx, ty, gx, gy;
int64_t idx;
// shareI[m][k]: slice of input (output rows x current K slice).
// shareM[n][k]: slice of the virtual weight matrix, stored transposed.
// Each row is padded by SHIFT extra columns.
__shared__ float shareI[SMLST][SMLST + SHIFT];
__shared__ float shareM[SMLST][SMLST + SHIFT];
// Per-thread accumulators for its TTILE x TTILE outputs.
float val[TTILE][TTILE] = {0};
// block_idx depends only on blockIdx.x, so every thread of a block runs the same iterations
// of this loop (and of the K loop inside), reaching the __syncthreads() calls together.
#pragma unroll
for (; block_idx < total_number_of_blocks; block_idx += gridDim.x) { // outer loop if the output matrix is too large
block_x = block_idx / num_block_width;
block_y = block_idx % num_block_width;
tx = threadIdx.x; ty = threadIdx.y;
// (gx, gy): this thread's first output element; the others are offset by multiples of SMLS.
gx = block_x * SMLST + tx;
gy = block_y * SMLST + ty;
#pragma unroll
for (int x_offset = 0; x_offset < TTILE; x_offset ++) {
#pragma unroll
for (int y_offset = 0; y_offset < TTILE ; y_offset ++) {
val[x_offset][y_offset] = 0.;
}
}
// Walk K in SMLST-wide slices. Each slice is staged as TTILE x TTILE sub-tiles of
// SMLS x SMLS per operand, with out-of-range elements zero-filled.
#pragma unroll
for (int i = 0 ; i < (K + SMLST - 1) / SMLST ; i ++) {
#pragma unroll
for (int x_offset_abs = 0; x_offset_abs < SMLST ; x_offset_abs += SMLS) {
#pragma unroll
for (int y_offset_abs = 0; y_offset_abs < SMLST ; y_offset_abs += SMLS) {
// SMLSxSMLS = (block_x, block_y, x_offset_abs/SMLS, y_offset_abs / SMLS)
// Base slot of the weight tile covering K rows starting at i*SMLST + x_offset_abs and
// N columns starting at block_y*SMLST + y_offset_abs (both multiples of SMLS). Element
// (tx, ty) of that tile is at idx + tx*SMLS + ty, which matches location_tiled on the
// element's own coordinates.
idx = location_tiled(i*SMLST + x_offset_abs, block_y * SMLST + y_offset_abs, random_numbers, hashed_weight_size);
//if ( tx == 0 && ty == 0) {
// printf("location_tiled: %ld %ld %ld\n", (int64_t)(i*SMLST + x_offset_abs),(int64_t)( block_y * SMLST + y_offset_abs), idx);
//}
if (i*SMLST + ty + y_offset_abs < K && gx + x_offset_abs < M) {
shareI[tx + x_offset_abs][ty + y_offset_abs] = input[(gx + x_offset_abs)][i* SMLST + ty + y_offset_abs]; // row major (gx, i*SMLS+ty)
} else {
shareI[tx + x_offset_abs][ty + y_offset_abs] = 0.;
}
if (i*SMLST + tx + x_offset_abs < K && gy + y_offset_abs < N) {
//shareM[ty + y_offset_abs][tx + x_offset_abs]= weights[i* SMLST + (tx + x_offset_abs) + (gy + y_offset_abs) * K]; // coumn major (i*SMLS+tx, gy)
shareM[ty + y_offset_abs][tx + x_offset_abs]= hashed_weights[idx + tx * SMLS + ty]; // coumn major (i*SMLS+tx, gy)
//shareM[ty + y_offset_abs][tx + x_offset_abs]= hashed_weights[location_tiled(i*SMLST + (tx + x_offset_abs), gy + y_offset_abs, random_numbers, hashed_weight_size)];
} else {
shareM[ty + y_offset_abs][tx + x_offset_abs] = 0.;
}
}
}
// Both shared slices must be complete before any thread reads them.
__syncthreads();
#pragma unroll
for (int x_offset = 0; x_offset < TTILE; x_offset ++) {
#pragma unroll
for (int y_offset = 0; y_offset < TTILE ; y_offset ++) {
#pragma unroll
for (int j = 0; j < SMLST; j ++ ) {
val[x_offset][y_offset] += shareI[tx + x_offset*SMLS][j] * shareM[ty + y_offset*SMLS][j];
}
}
}
// All reads must finish before the next K slice overwrites the shared arrays.
__syncthreads();
}
// Bounds-checked store of this thread's TTILE x TTILE outputs.
#pragma unroll
for (int x_offset = 0; x_offset < TTILE; x_offset ++) {
#pragma unroll
for (int y_offset = 0; y_offset < TTILE ; y_offset ++) {
if ((gx + x_offset * SMLS) < M && (gy + y_offset * SMLS) < N) {
output[(gx + x_offset*SMLS)][(gy + y_offset * SMLS)] = val[x_offset][y_offset];
}
}
}
}
}
// Experimental forward kernel; it is not launched by any host function in this file.
// Strategy: each block owns SMLS x SMLS tiles of `output` and strides over them via
// blockIdx.x / gridDim.x. It uses a 1-D thread layout (threadIdx.x, any block size).
// Weight tile (interim, block_y), element (k, n), is read from
// hashed_weights[location_tile(interim, block_y, k, n, chunk_size, ...)].
// Shared memory: three scalar_t [SMLS][SMLS] arrays.
// The original version had no barriers between shared-memory writes and cross-thread reads,
// which is a data race. The __syncthreads() calls below fix it. They are reached uniformly:
// the oblock and interim loop bounds depend only on blockIdx.x and kernel arguments.
template<typename scalar_t>
__global__ void rz_linear_forward_cuda_kernelXXX(
    torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits> hashed_weights,
    torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> input,
    torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> output,
    torch::PackedTensorAccessor32<int64_t, 1, torch::RestrictPtrTraits> random_numbers,
    int batch,
    int input_dim,
    int output_dim,
    int chunk_size,
    int hashed_weight_size
)
{
  const int total_output_blocks_x = (batch + SMLS - 1) / SMLS;
  const int total_output_blocks_y = (output_dim + SMLS - 1) / SMLS;
  const int total_interim_blocks_k = (input_dim + SMLS - 1) / SMLS;
  const int total_output_blocks = total_output_blocks_x * total_output_blocks_y;
  const int tid = threadIdx.x;
  // TODO: a +1 padding column would keep consecutive rows out of the same shared-memory banks.
  __shared__ scalar_t local_output [SMLS][SMLS];
  __shared__ scalar_t local_input [SMLS][SMLS];
  __shared__ scalar_t local_weights [SMLS][SMLS];  // column-major: local_weights[n][k]

  for (int oblock = blockIdx.x; oblock < total_output_blocks; oblock += gridDim.x) {
    const int block_x = oblock % total_output_blocks_x;
    const int block_y = oblock / total_output_blocks_x;

    // Zero the accumulator tile (row-major element assignment keeps a warp in one row).
    for (int itid = tid; itid < SMLS * SMLS; itid += blockDim.x) {
      local_output[itid / SMLS][itid % SMLS] = 0;
    }

    for (int interim = 0; interim < total_interim_blocks_k; interim++) {
      // Stage input tile (block_x, interim), zero-padded at the matrix edges.
      for (int itid = tid; itid < SMLS * SMLS; itid += blockDim.x) {
        const int r = itid / SMLS;
        const int c = itid % SMLS;
        if (block_x * SMLS + r < batch && interim * SMLS + c < input_dim) {
          local_input[r][c] = input[block_x * SMLS + r][interim * SMLS + c];
        } else {
          local_input[r][c] = 0;
        }
      }
      // Stage weight tile (interim, block_y), hashing the tile coordinates.
      for (int itid = tid; itid < SMLS * SMLS; itid += blockDim.x) {
        const int k = itid % SMLS;
        const int n = itid / SMLS;
        if (interim * SMLS + k < input_dim && block_y * SMLS + n < output_dim) {
          const int64_t loc = location_tile(interim, block_y, k, n, chunk_size, random_numbers, hashed_weight_size);
          local_weights[n][k] = hashed_weights[loc];
        } else {
          local_weights[n][k] = 0;
        }
      }
      // Tiles (and, on the first pass, the zeroed accumulator) must be complete before the
      // multiply below reads elements written by other threads.
      __syncthreads();

      // Local tile multiply. (ix, iy) walks a rotated diagonal, so each itid maps to a distinct
      // output element (presumably to spread shared-memory accesses across banks).
      for (int itid = tid; itid < SMLS * SMLS; itid += blockDim.x) {
        const int r = itid / SMLS;
        const int ir = itid % SMLS;
        const int ix = (ir <= r) ? r - ir : (SMLS - (ir - r));
        const int iy = ir;
        scalar_t val = 0;
        #pragma unroll
        for (int k = 0; k < SMLS; k++) {
          val += local_input[ix][k] * local_weights[iy][k];
        }
        local_output[ix][iy] += val;
      }
      // Every read of local_input/local_weights must finish before the next iteration
      // overwrites them. This also publishes local_output for the store below.
      __syncthreads();
    }

    // Write the accumulator tile back. Each thread touches the same elements it zeroed,
    // so no extra barrier is needed when there are no interim iterations.
    for (int itid = tid; itid < SMLS * SMLS; itid += blockDim.x) {
      const int r = itid / SMLS;
      const int c = itid % SMLS;
      if (block_x * SMLS + r < batch && block_y * SMLS + c < output_dim) {
        output[block_x * SMLS + r][block_y * SMLS + c] = local_output[r][c];
      }
    }
  }
}
// Host launcher for the tiled forward kernel.
// Returns a (input.size(0), output_dim) tensor with input's options.
// Block: SMLS x SMLS threads, each block producing SMLST x SMLST outputs per iteration.
// Grid: one block per SMLST x SMLST output tile, capped at MAX_GRID_SIZE (the kernel strides
// over any remaining tiles). Launched on the current CUDA stream of input's device.
// Throws (TORCH_CHECK) if hashed_weights has fewer than SMLS*SMLS entries (location_tiled
// needs room for one whole tile) or if the launch fails.
torch::Tensor rz_linear_forward_cuda_tiled (
    const torch::Tensor& hashed_weights, // 1 x n
    const torch::Tensor& input, // b x d1
    const torch::Tensor& random_numbers,
    int input_dim,
    int output_dim,
    int chunk_size
)
{
  const int M = input.size(0);
  const int K = input_dim;
  const int N = output_dim;
  const int64_t hashedWeightSize = hashed_weights.size(0);
  auto output = at::empty({input.size(0), output_dim}, input.options());
  if (M == 0 || N == 0) {
    // Nothing to compute; a zero-sized grid would be an invalid launch.
    return output;
  }
  TORCH_CHECK(hashedWeightSize >= SMLS * SMLS,
              "rz_linear_forward_cuda_tiled: hashed_weights must have at least ", SMLS * SMLS,
              " elements, got ", hashedWeightSize);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(input.device().index());
  const dim3 block(SMLS, SMLS, 1);
  // The kernel tiles the output in SMLST x SMLST blocks. Sizing the grid with SMLS (as
  // before) launched 16x more blocks than there are tiles; the extras did no work.
  const int total_number_of_blocks = ((M + SMLST - 1) / SMLST) * ((N + SMLST - 1) / SMLST);
  const int grid = total_number_of_blocks < MAX_GRID_SIZE ? total_number_of_blocks : MAX_GRID_SIZE;
  AT_DISPATCH_FLOATING_TYPES(hashed_weights.scalar_type(), "rz_linear_forward_cuda", ([&] {
    rz_linear_forward_cuda_kernel_tiled<scalar_t><<<grid, block, 0, stream>>>(
        hashed_weights.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>(),
        input.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
        output.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
        random_numbers.packed_accessor32<int64_t, 1, torch::RestrictPtrTraits>(),
        M,
        K,
        N,
        chunk_size,
        hashed_weights.size(0)
    );
  }));
  const cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess, "rz_linear_forward_cuda_tiled kernel launch failed: ", cudaGetErrorString(err));
  return output;
}
// C++ interface
// Argument-validation helpers for the Python-facing entry points. Each one raises via
// TORCH_CHECK with a message naming the offending argument.
// Tensor::is_cuda() replaces the deprecated x.type().is_cuda().
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
// Wrapped in do/while(0) so `CHECK_INPUT(x);` is a single statement (safe in an unbraced if).
#define CHECK_INPUT(x) do { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); } while (0)
// Python-facing forward entry point.
// Validates that every tensor argument is a contiguous CUDA tensor, then dispatches to the
// tiled launcher (tiled == true) or the chunked launcher (tiled == false).
torch::Tensor rz_linear_forward(
    const torch::Tensor& hashed_weights,
    const torch::Tensor& input,
    const torch::Tensor& random_numbers,
    int input_dim,
    int output_dim,
    int chunk_size,
    bool tiled
)
{
  CHECK_INPUT(hashed_weights);
  CHECK_INPUT(input);
  CHECK_INPUT(random_numbers);
  if (!tiled) {
    return rz_linear_forward_cuda(hashed_weights, input, random_numbers, input_dim, output_dim, chunk_size);
  }
  return rz_linear_forward_cuda_tiled(hashed_weights, input, random_numbers, input_dim, output_dim, chunk_size);
}
// Host launcher for the backward pass. Returns (input_grad, weight_grad):
//   input_grad  — shaped like input; filled by rz_linear_backward_cuda_kernel_input
//                 (grid over batch rows, threads over input_dim).
//   weight_grad — shaped like hashed_weights; filled by rz_linear_backward_cuda_kernel_weight
//                 (grid over input_dim, threads over output_dim, atomicAdd into hashed slots).
// Both kernels run on the current CUDA stream of input's device.
// Throws (TORCH_CHECK) on invalid chunk_size / hashed weight size or on a failed launch.
std::tuple<torch::Tensor, torch::Tensor> rz_linear_backward_cuda (
    const torch::Tensor& out_grad,
    const torch::Tensor& hashed_weights, // 1 x n
    const torch::Tensor& input, // b x d1
    const torch::Tensor& random_numbers,
    int input_dim,
    int output_dim,
    int chunk_size,
    bool tiled
)
{
  // Both gradients start at zero: the input kernel accumulates with += and the weight kernel
  // with atomicAdd.
  auto input_grad = at::zeros({input.size(0), input.size(1)}, input.options());
  // weight_grad mirrors the hashed weight storage, so it takes hashed_weights' options.
  auto weight_grad = at::zeros({hashed_weights.size(0)}, hashed_weights.options());
  const int batch = input.size(0);
  if (batch == 0 || input_dim == 0 || output_dim == 0) {
    // Every contribution is an empty sum, and a zero-sized grid/block would be an invalid launch.
    return std::tuple<torch::Tensor, torch::Tensor>(input_grad, weight_grad);
  }
  TORCH_CHECK(tiled || chunk_size > 0,
              "rz_linear_backward_cuda: chunk_size must be positive when tiled is false, got ", chunk_size);
  TORCH_CHECK(!tiled || hashed_weights.size(0) >= SMLS * SMLS,
              "rz_linear_backward_cuda: hashed_weights must have at least ", SMLS * SMLS,
              " elements when tiled, got ", hashed_weights.size(0));

  cudaStream_t stream = at::cuda::getCurrentCUDAStream(input.device().index());
  auto capped = [](int value, int cap) { return value < cap ? value : cap; };
  auto check_launch = [](const char* what) {
    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, what, " kernel launch failed: ", cudaGetErrorString(err));
  };

  //input_grad TODO cannot take advantage of chunk in robe-z here.
  dim3 block(1, capped(input_dim, MAX_BLOCK_SIZE), 1);
  dim3 grid(capped(batch, MAX_GRID_SIZE), 1, 1);
  AT_DISPATCH_FLOATING_TYPES(hashed_weights.scalar_type(), "rz_linear_backward_cuda_input", ([&] {
    rz_linear_backward_cuda_kernel_input<scalar_t><<<grid, block, 0, stream>>>(
        hashed_weights.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>(),
        input.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
        out_grad.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
        input_grad.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
        random_numbers.packed_accessor32<int64_t, 1, torch::RestrictPtrTraits>(),
        batch,
        input_dim,
        output_dim,
        chunk_size,
        hashed_weights.size(0),
        tiled
    );
  }));
  check_launch("rz_linear_backward_cuda_input");

  //weight_grad TODO cannot take advantage of chunk in robe-z here.
  block = dim3(1, capped(output_dim, MAX_BLOCK_SIZE), 1);
  grid = dim3(capped(input_dim, MAX_GRID_SIZE), 1, 1);
  AT_DISPATCH_FLOATING_TYPES(hashed_weights.scalar_type(), "rz_linear_backward_cuda_weight", ([&] {
    rz_linear_backward_cuda_kernel_weight<scalar_t><<<grid, block, 0, stream>>>(
        hashed_weights.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>(),
        input.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
        out_grad.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
        weight_grad.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>(),
        random_numbers.packed_accessor32<int64_t, 1, torch::RestrictPtrTraits>(),
        batch,
        input_dim,
        output_dim,
        chunk_size,
        hashed_weights.size(0),
        tiled
    );
  }));
  check_launch("rz_linear_backward_cuda_weight");
  return std::tuple<torch::Tensor, torch::Tensor>(input_grad, weight_grad);
}
// Python-facing backward entry point.
// Checks that all four tensors are contiguous CUDA tensors, then returns the
// (input_grad, weight_grad) pair computed by rz_linear_backward_cuda.
std::tuple<torch::Tensor, torch::Tensor> rz_linear_backward(
    const torch::Tensor& out_grad,
    const torch::Tensor& hashed_weights,
    const torch::Tensor& input,
    const torch::Tensor& random_numbers,
    int input_dim,
    int output_dim,
    int chunk_size,
    bool tiled
)
{
  CHECK_INPUT(hashed_weights);
  CHECK_INPUT(input);
  CHECK_INPUT(out_grad);
  CHECK_INPUT(random_numbers);
  auto grads = rz_linear_backward_cuda(out_grad, hashed_weights, input, random_numbers,
                                       input_dim, output_dim, chunk_size, tiled);
  return grads;
}
// Fills IDX[row][col] with the hashed_weights slot of virtual weight (row, col):
// location_tiled when `tiled`, otherwise the chunked location().
// Threads stride over columns (threadIdx.x / blockDim.x) and blocks over rows
// (blockIdx.x / gridDim.x), so each in-range (row, col) pair is written exactly once.
__global__ void rz_linear_idx(torch::PackedTensorAccessor32<int64_t, 1, torch::RestrictPtrTraits> random_numbers,
    torch::PackedTensorAccessor32<int64_t, 2, torch::RestrictPtrTraits> IDX,
    int input_dim,
    int output_dim,
    int chunk_size,
    int weight_size,
    bool tiled) {
  for (int col = threadIdx.x; col < output_dim; col += blockDim.x) {
    for (int row = blockIdx.x; row < input_dim; row += gridDim.x) {
      IDX[row][col] = tiled
          ? location_tiled(row, col, random_numbers, weight_size)
          : location(row, col, chunk_size, random_numbers, weight_size);
    }
  }
}
// Python-facing helper: materialises the (input_dim, output_dim) table of hashed_weights
// slots used by the kernels, via rz_linear_idx. IDX takes random_numbers' options, so
// random_numbers is expected to be int64 (the accessor requires it).
// Grid: min(input_dim, MAX_GRID_SIZE) blocks. Block: min(output_dim, MAX_BLOCK_SIZE) threads.
// Launched on the current CUDA stream of random_numbers' device.
// Throws (TORCH_CHECK) on invalid chunk_size / weight_size or on a failed launch.
torch::Tensor rz_get_idx(torch::Tensor& random_numbers, int input_dim, int output_dim, int chunk_size, int weight_size, bool tiled) {
  CHECK_INPUT(random_numbers);
  auto IDX = at::zeros({input_dim, output_dim}, random_numbers.options());
  if (input_dim == 0 || output_dim == 0) {
    // Empty table; a zero-sized grid/block would be an invalid launch.
    return IDX;
  }
  TORCH_CHECK(tiled || chunk_size > 0,
              "rz_get_idx: chunk_size must be positive when tiled is false, got ", chunk_size);
  TORCH_CHECK(!tiled || weight_size >= SMLS * SMLS,
              "rz_get_idx: weight_size must be at least ", SMLS * SMLS, " when tiled, got ", weight_size);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(random_numbers.device().index());
  const int block = output_dim < MAX_BLOCK_SIZE ? output_dim : MAX_BLOCK_SIZE;
  const int grid = input_dim < MAX_GRID_SIZE ? input_dim : MAX_GRID_SIZE;
  rz_linear_idx<<<grid, block, 0, stream>>>(
      random_numbers.packed_accessor32<int64_t, 1, torch::RestrictPtrTraits>(),
      IDX.packed_accessor32<int64_t, 2, torch::RestrictPtrTraits>(),
      input_dim,
      output_dim,
      chunk_size,
      weight_size,
      tiled
  );
  const cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess, "rz_linear_idx kernel launch failed: ", cudaGetErrorString(err));
  return IDX;
}
// Python bindings for the robe_z_mm extension:
//   forward  -> rz_linear_forward   (hashed linear layer forward pass)
//   backward -> rz_linear_backward  (returns (input_grad, weight_grad))
//   get_idx  -> rz_get_idx          (table of hashed slots per virtual weight)
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
  mod.def("forward", &rz_linear_forward, "robe_z_mm (CUDA)");
  mod.def("backward", &rz_linear_backward, "robe_z_mm (CUDA)");
  mod.def("get_idx", &rz_get_idx, "robe_z_mm (CUDA) ");
}