commit
8e0248b2ec
7 changed files with 163 additions and 0 deletions
@ -0,0 +1,4 @@ |
|||
target |
|||
.venv |
|||
uv.lock |
|||
Cargo.lock |
|||
@ -0,0 +1,7 @@ |
|||
[package] |
|||
name = "rust-triton" |
|||
version = "0.1.0" |
|||
edition = "2021" |
|||
|
|||
[dependencies] |
|||
cust = "0.3.2" |
|||
@ -0,0 +1,22 @@ |
|||
# Installation |
|||
|
|||
You need to install `triton` and `torch` on your system or create a venv with: |
|||
```bash |
|||
# With uv |
|||
uv sync |
|||
|
|||
# Or manually |
|||
python -m venv .venv |
|||
source .venv/bin/activate |
|||
pip install build |
|||
pip install -e . |
|||
```
|||
|
|||
Then, set the Rust nightly toolchain and build: |
|||
|
|||
```bash |
|||
rustup toolchain install nightly |
|||
cargo run |
|||
``` |
|||
|
|||
The cargo build hooks will take care of compiling the PTX files with Python. |
|||
@ -0,0 +1,30 @@ |
|||
use std::path::Path; |
|||
use std::process::Command; |
|||
|
|||
/// Build script: runs the Python/Triton generator so that `add_kernel.ptx`
/// (embedded via `include_str!` in src/main.rs) exists before rustc runs.
fn main() {
    // Script that JIT-compiles the Triton kernel and writes add_kernel.ptx.
    let python_script = "src/kernel.py";

    // Prefer the project-local virtualenv interpreter when one exists,
    // falling back to whatever `python` resolves to on PATH.
    let python_executable = if Path::new(".venv").exists() {
        if cfg!(windows) {
            ".venv\\Scripts\\python.exe"
        } else {
            ".venv/bin/python"
        }
    } else {
        "python"
    };

    // Run the generator. Include the interpreter and script in the spawn
    // error, and the exit status in the failure panic, so a broken build
    // is diagnosable from the cargo output alone.
    let status = Command::new(python_executable)
        .arg(python_script)
        .status()
        .unwrap_or_else(|e| {
            panic!("Failed to execute `{python_executable} {python_script}`: {e}")
        });

    if !status.success() {
        panic!("CUDA kernel compilation failed ({python_script} exited with {status})");
    }

    // Rerun the PTX generation script only if it has changed.
    println!("cargo:rerun-if-changed=src/kernel.py");
}
|||
@ -0,0 +1,15 @@ |
|||
[project] |
|||
name = "rust-triton" |
|||
version = "0.1.0" |
|||
description = "Run Triton-compiled CUDA kernels from Rust via the cust crate"
|||
readme = "README.md" |
|||
requires-python = ">=3.12" |
|||
dependencies = [ |
|||
"triton>=3.0.0", |
|||
"setuptools>=75.1.0", |
|||
"torch>=2.4.1", |
|||
] |
|||
|
|||
[build-system] |
|||
requires = ["hatchling"] |
|||
build-backend = "hatchling.build" |
|||
@ -0,0 +1,33 @@ |
|||
import triton |
|||
import triton.language as tl |
|||
import torch |
|||
|
|||
|
|||
@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Element-wise vector add: output[i] = x[i] + y[i] for i < n_elements.

    Each program instance handles one contiguous BLOCK_SIZE-wide slice of
    the inputs; the mask guards the tail when n_elements is not a multiple
    of BLOCK_SIZE.
    """
    program_index = tl.program_id(axis=0)
    first_offset = program_index * BLOCK_SIZE
    offsets = first_offset + tl.arange(0, BLOCK_SIZE)
    in_bounds = offsets < n_elements

    lhs = tl.load(x_ptr + offsets, mask=in_bounds)
    rhs = tl.load(y_ptr + offsets, mask=in_bounds)

    tl.store(output_ptr + offsets, lhs + rhs, mask=in_bounds)
|||
|
|||
|
|||
# Build-time driver: launch the kernel once so Triton JIT-compiles it,
# then dump the generated PTX for the Rust build to embed.

# Dummy tensors — their contents are irrelevant; the launch below exists
# only to trigger compilation for this argument signature.
N_ELEMENTS = 1024
x = torch.zeros(N_ELEMENTS).cuda()
y = torch.zeros(N_ELEMENTS).cuda()
output = torch.zeros(N_ELEMENTS).cuda()


def grid(meta):
    # One program per BLOCK_SIZE-wide chunk of the input.
    return (triton.cdiv(N_ELEMENTS, meta["BLOCK_SIZE"]),)


# Compile and launch with BLOCK_SIZE=256 baked in as a constexpr.
add_kernel[grid](x, y, output, N_ELEMENTS, BLOCK_SIZE=256)
# NOTE(review): this reaches into Triton's internal per-device kernel cache
# (add_kernel.cache[0]) and assumes exactly one compiled variant exists.
# That is private, version-fragile API — confirm it still holds for the
# triton release pinned in pyproject.toml before upgrading.
with open("add_kernel.ptx", "w") as a:
    print(list(add_kernel.cache[0].values())[0].asm["ptx"], file=a)
|||
@ -0,0 +1,52 @@ |
|||
use cust::prelude::*; |
|||
use std::error::Error; |
|||
|
|||
// Number of f32 elements in each input/output vector.
const SIZE: usize = 4096;

// PTX generated at build time by src/kernel.py (invoked from build.rs).
static PTX: &str = include_str!("../add_kernel.ptx");
|||
|
|||
fn run() -> Result<(), Box<dyn Error>> { |
|||
let _ctx = cust::quick_init().expect("Could not create CUDA context"); |
|||
|
|||
let x: [f32; SIZE] = std::array::from_fn(|i| i as f32 + 1.); |
|||
let y: [f32; SIZE] = std::array::from_fn(|i| i as f32 + 1.); |
|||
let o: [f32; SIZE] = [0.0; SIZE]; |
|||
|
|||
let module = Module::from_ptx(PTX, &[]).expect("Could not create module from PTX"); |
|||
|
|||
let stream = Stream::new(StreamFlags::NON_BLOCKING, None)?; |
|||
|
|||
let x_d = x.as_slice().as_dbuf()?; |
|||
let y_d = y.as_slice().as_dbuf()?; |
|||
let o_d = o.as_slice().as_dbuf()?; |
|||
|
|||
let func = module |
|||
.get_function("add_kernel") |
|||
.expect("could not find the kernel!"); |
|||
|
|||
let block_size = cust::function::BlockSize { x: 128, y: 1, z: 1 }; |
|||
let grid_size = cust::function::GridSize { |
|||
x: SIZE as u32 / block_size.x, |
|||
y: 1, |
|||
z: 1, |
|||
}; |
|||
unsafe { |
|||
launch!(func<<<grid_size, block_size, 9216, stream>>>( |
|||
x_d.as_device_ptr(), |
|||
y_d.as_device_ptr(), |
|||
o_d.as_device_ptr(), |
|||
SIZE as i32, |
|||
))?; |
|||
} |
|||
stream.synchronize().expect("failed to sync"); |
|||
|
|||
let o = o_d.as_slice().as_host_vec()?; |
|||
|
|||
println!("o: {:?}", &o[..20]); |
|||
|
|||
Ok(()) |
|||
} |
|||
|
|||
fn main() { |
|||
run().expect("something went wrong"); |
|||
} |
|||
Loading…
Reference in new issue