commit
8e0248b2ec
7 changed files with 163 additions and 0 deletions
@ -0,0 +1,4 @@ |
|||||
|
target |
||||
|
.venv |
||||
|
uv.lock |
||||
|
Cargo.lock |
||||
@ -0,0 +1,7 @@ |
|||||
|
[package] |
||||
|
name = "rust-triton" |
||||
|
version = "0.1.0" |
||||
|
edition = "2021" |
||||
|
|
||||
|
[dependencies] |
||||
|
cust = "0.3.2" |
||||
@ -0,0 +1,22 @@ |
|||||
|
# Installation |
||||
|
|
||||
|
You need to install `triton` and `torch` on your system or create a venv with: |
||||
|
```bash |
||||
|
# With uv |
||||
|
uv sync |
||||
|
|
||||
|
# Or manually |
||||
|
python -m venv .venv |
||||
|
source .venv/bin/activate |
||||
|
pip install build |
||||
|
pip install -e . |
||||
|
```
||||
|
|
||||
|
Then, set the Rust nightly toolchain and build: |
||||
|
|
||||
|
```bash |
||||
|
rustup toolchain install nightly
cargo +nightly run
||||
|
``` |
||||
|
|
||||
|
The cargo build hooks will take care of compiling the PTX files with Python. |
||||
@ -0,0 +1,30 @@ |
|||||
|
// Build script: generate the PTX for the Triton kernel by running
// `src/kernel.py` with Python before the Rust crate is compiled.

use std::path::Path;
use std::process::Command;

fn main() {
    // Python script that JIT-compiles the Triton kernel and dumps its PTX.
    let python_script = "src/kernel.py";

    // Prefer the project-local virtualenv interpreter when one exists,
    // falling back to whatever `python` is on PATH.
    let python_executable = if Path::new(".venv").exists() {
        if cfg!(windows) {
            ".venv\\Scripts\\python.exe"
        } else {
            ".venv/bin/python"
        }
    } else {
        "python"
    };

    // Run the PTX-generation script. Name the interpreter in the error so a
    // missing/broken Python install is diagnosable from the build log.
    let status = Command::new(python_executable)
        .arg(python_script)
        .status()
        .unwrap_or_else(|e| panic!("Failed to execute {python_executable}: {e}"));

    if !status.success() {
        // Include the exit status so failures inside the script are visible.
        panic!("CUDA kernel compilation failed: {python_script} exited with {status}");
    }

    // Re-run PTX generation whenever the kernel source changes.
    println!("cargo:rerun-if-changed=src/kernel.py");
}
@ -0,0 +1,15 @@ |
|||||
|
[project] |
||||
|
name = "rust-triton" |
||||
|
version = "0.1.0" |
||||
|
description = "Add your description here" |
||||
|
readme = "README.md" |
||||
|
requires-python = ">=3.12" |
||||
|
dependencies = [ |
||||
|
"triton>=3.0.0", |
||||
|
"setuptools>=75.1.0", |
||||
|
"torch>=2.4.1", |
||||
|
] |
||||
|
|
||||
|
[build-system] |
||||
|
requires = ["hatchling"] |
||||
|
build-backend = "hatchling.build" |
||||
@ -0,0 +1,33 @@ |
|||||
|
import triton |
||||
|
import triton.language as tl |
||||
|
import torch |
||||
|
|
||||
|
|
||||
|
@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Element-wise vector addition: output[i] = x[i] + y[i].

    Each program instance handles one contiguous chunk of BLOCK_SIZE
    elements; the mask keeps trailing out-of-range lanes inert.
    """
    program_id = tl.program_id(axis=0)
    chunk_start = program_id * BLOCK_SIZE
    offsets = chunk_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = offsets < n_elements

    lhs = tl.load(x_ptr + offsets, mask=in_bounds)
    rhs = tl.load(y_ptr + offsets, mask=in_bounds)
    tl.store(output_ptr + offsets, lhs + rhs, mask=in_bounds)
||||
|
|
||||
|
|
||||
|
# Compile the kernel once and dump its PTX so the Rust build can embed it.
N_ELEMENTS = 1024
BLOCK_SIZE = 256  # elements handled per Triton program instance

x = torch.zeros(N_ELEMENTS).cuda()
y = torch.zeros(N_ELEMENTS).cuda()
output = torch.zeros(N_ELEMENTS).cuda()


def grid(meta):
    # One program per BLOCK_SIZE-sized chunk of the input.
    return (triton.cdiv(N_ELEMENTS, meta["BLOCK_SIZE"]),)


# A throwaway launch forces Triton to JIT-compile the kernel and populate
# its compilation cache.
add_kernel[grid](x, y, output, N_ELEMENTS, BLOCK_SIZE=BLOCK_SIZE)

# NOTE(review): `add_kernel.cache` is Triton-internal API; this assumes the
# first cached compilation is the one launched above — verify against the
# installed Triton version.
with open("add_kernel.ptx", "w") as ptx_file:
    compiled = list(add_kernel.cache[0].values())[0]
    print(compiled.asm["ptx"], file=ptx_file)
||||
@ -0,0 +1,52 @@ |
|||||
|
use cust::prelude::*; |
||||
|
use std::error::Error; |
||||
|
|
||||
|
// Number of f32 elements in each host vector.
const SIZE: usize = 4096;

// PTX generated by build.rs (via src/kernel.py), embedded at compile time.
static PTX: &str = include_str!("../add_kernel.ptx");
||||
|
|
||||
|
fn run() -> Result<(), Box<dyn Error>> { |
||||
|
let _ctx = cust::quick_init().expect("Could not create CUDA context"); |
||||
|
|
||||
|
let x: [f32; SIZE] = std::array::from_fn(|i| i as f32 + 1.); |
||||
|
let y: [f32; SIZE] = std::array::from_fn(|i| i as f32 + 1.); |
||||
|
let o: [f32; SIZE] = [0.0; SIZE]; |
||||
|
|
||||
|
let module = Module::from_ptx(PTX, &[]).expect("Could not create module from PTX"); |
||||
|
|
||||
|
let stream = Stream::new(StreamFlags::NON_BLOCKING, None)?; |
||||
|
|
||||
|
let x_d = x.as_slice().as_dbuf()?; |
||||
|
let y_d = y.as_slice().as_dbuf()?; |
||||
|
let o_d = o.as_slice().as_dbuf()?; |
||||
|
|
||||
|
let func = module |
||||
|
.get_function("add_kernel") |
||||
|
.expect("could not find the kernel!"); |
||||
|
|
||||
|
let block_size = cust::function::BlockSize { x: 128, y: 1, z: 1 }; |
||||
|
let grid_size = cust::function::GridSize { |
||||
|
x: SIZE as u32 / block_size.x, |
||||
|
y: 1, |
||||
|
z: 1, |
||||
|
}; |
||||
|
unsafe { |
||||
|
launch!(func<<<grid_size, block_size, 9216, stream>>>( |
||||
|
x_d.as_device_ptr(), |
||||
|
y_d.as_device_ptr(), |
||||
|
o_d.as_device_ptr(), |
||||
|
SIZE as i32, |
||||
|
))?; |
||||
|
} |
||||
|
stream.synchronize().expect("failed to sync"); |
||||
|
|
||||
|
let o = o_d.as_slice().as_host_vec()?; |
||||
|
|
||||
|
println!("o: {:?}", &o[..20]); |
||||
|
|
||||
|
Ok(()) |
||||
|
} |
||||
|
|
||||
|
fn main() { |
||||
|
run().expect("something went wrong"); |
||||
|
} |
||||
Loading…
Reference in new issue