
feat: sum two vecs in CUDA

master · Ramon Calvo · 1 year ago · commit 8e0248b2ec
7 changed files, 163 insertions(+):

  1. .gitignore (+4)
  2. Cargo.toml (+7)
  3. README.md (+22)
  4. build.rs (+30)
  5. pyproject.toml (+15)
  6. src/kernel.py (+33)
  7. src/main.rs (+52)

.gitignore (+4)

@@ -0,0 +1,4 @@
target
.venv
uv.lock
Cargo.lock

Cargo.toml (+7)

@@ -0,0 +1,7 @@
[package]
name = "rust-triton"
version = "0.1.0"
edition = "2021"

[dependencies]
cust = "0.3.2"

README.md (+22)

@@ -0,0 +1,22 @@
# Installation

You need to install `triton` and `torch` on your system, or create a venv with:

```bash
# With uv
uv sync

# Or manually
python -m venv .venv
source .venv/bin/activate
pip install build
pip install -e .
```

Then install the Rust nightly toolchain and build:

```bash
rustup toolchain install nightly
cargo run
```

The Cargo build script (`build.rs`) takes care of compiling the kernel to PTX with Python.
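For orientation, the core of that build script (shown in full below) reduces to a `Command` invocation plus a rerun directive. This condensed sketch assumes `python` is on `PATH` and omits the `.venv` detection that the real `build.rs` performs:

```rust
// Condensed sketch of build.rs (the full version below handles .venv detection).
use std::process::Command;

fn main() {
    // Generate add_kernel.ptx by running the Triton script.
    let status = Command::new("python")
        .arg("src/kernel.py")
        .status()
        .expect("Failed to execute Python script");
    assert!(status.success(), "CUDA kernel compilation failed");

    // Recompile the PTX whenever the kernel source changes.
    println!("cargo:rerun-if-changed=src/kernel.py");
}
```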

build.rs (+30)

@@ -0,0 +1,30 @@
use std::path::Path;
use std::process::Command;

fn main() {
    // Path to your Python script
    let python_script = "src/kernel.py";

    let python_executable = if Path::new(".venv").exists() {
        if cfg!(windows) {
            ".venv\\Scripts\\python.exe"
        } else {
            ".venv/bin/python"
        }
    } else {
        "python"
    };

    // Run the Python script
    let status = Command::new(python_executable)
        .arg(python_script)
        .status()
        .expect("Failed to execute Python script");

    if !status.success() {
        panic!("CUDA kernel compilation failed");
    }

    // Rerun the PTX generation script if it has changed
    println!("cargo:rerun-if-changed=src/kernel.py");
}
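A possible refinement, not part of this commit: `status()` lets the Python script's stderr interleave with cargo's own output, where it can scroll past unnoticed. A minimal sketch that instead captures the output and replays it through cargo's warning channel, assuming the same `python_executable` and `python_script` values as above:

```rust
// Sketch only: surface the Triton script's stderr as cargo build warnings.
use std::process::Command;

fn compile_kernel(python_executable: &str, python_script: &str) {
    let output = Command::new(python_executable)
        .arg(python_script)
        .output()
        .expect("Failed to execute Python script");

    if !output.status.success() {
        // Cargo prints lines prefixed with `cargo:warning=` in the build log.
        for line in String::from_utf8_lossy(&output.stderr).lines() {
            println!("cargo:warning={line}");
        }
        panic!("CUDA kernel compilation failed");
    }
}
```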

pyproject.toml (+15)

@@ -0,0 +1,15 @@
[project]
name = "rust-triton"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
    "triton>=3.0.0",
    "setuptools>=75.1.0",
    "torch>=2.4.1",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

src/kernel.py (+33)

@@ -0,0 +1,33 @@
import triton
import triton.language as tl
import torch


@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance adds one BLOCK_SIZE-wide chunk of the vectors.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements  # guard the tail when n_elements % BLOCK_SIZE != 0
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    tl.store(output_ptr + offsets, output, mask=mask)


# Dummy float32 tensors: the call below only serves to trigger JIT compilation.
N_ELEMENTS = 1024
x = torch.zeros(N_ELEMENTS).cuda()
y = torch.zeros(N_ELEMENTS).cuda()
output = torch.zeros(N_ELEMENTS).cuda()


def grid(meta):
    return (triton.cdiv(N_ELEMENTS, meta["BLOCK_SIZE"]),)


add_kernel[grid](x, y, output, N_ELEMENTS, BLOCK_SIZE=256)

# Dump the generated PTX. NOTE: `JITFunction.cache` is a Triton-internal
# structure (device index -> compiled kernels); its layout may change
# between versions.
with open("add_kernel.ptx", "w") as a:
    print(list(add_kernel.cache[0].values())[0].asm["ptx"], file=a)

src/main.rs (+52)

@@ -0,0 +1,52 @@
use cust::prelude::*;
use std::error::Error;

const SIZE: usize = 4096;
static PTX: &str = include_str!("../add_kernel.ptx");

fn run() -> Result<(), Box<dyn Error>> {
    let _ctx = cust::quick_init().expect("Could not create CUDA context");

    let x: [f32; SIZE] = std::array::from_fn(|i| i as f32 + 1.);
    let y: [f32; SIZE] = std::array::from_fn(|i| i as f32 + 1.);
    let o: [f32; SIZE] = [0.0; SIZE];

    let module = Module::from_ptx(PTX, &[]).expect("Could not create module from PTX");
    let stream = Stream::new(StreamFlags::NON_BLOCKING, None)?;

    // Copy the host arrays into device buffers.
    let x_d = x.as_slice().as_dbuf()?;
    let y_d = y.as_slice().as_dbuf()?;
    let o_d = o.as_slice().as_dbuf()?;

    let func = module
        .get_function("add_kernel")
        .expect("could not find the kernel!");

    // Triton compiles with num_warps = 4 by default, i.e. 128 threads per
    // block. BLOCK_SIZE = 256 is the number of *elements* per program, so
    // this grid launches more programs than strictly needed; the mask in
    // the kernel discards the out-of-range offsets.
    let block_size = cust::function::BlockSize { x: 128, y: 1, z: 1 };
    let grid_size = cust::function::GridSize {
        x: SIZE as u32 / block_size.x,
        y: 1,
        z: 1,
    };

    unsafe {
        // 9216 bytes of dynamic shared memory, presumably taken from the
        // Triton-compiled kernel's metadata.
        launch!(func<<<grid_size, block_size, 9216, stream>>>(
            x_d.as_device_ptr(),
            y_d.as_device_ptr(),
            o_d.as_device_ptr(),
            SIZE as i32,
        ))?;
    }

    stream.synchronize().expect("failed to sync");

    // Copy the result back to the host and show the first few elements.
    let o = o_d.as_slice().as_host_vec()?;
    println!("o: {:?}", &o[..20]);

    Ok(())
}

fn main() {
    run().expect("something went wrong");
}
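Since both inputs are initialized to `i + 1`, each output element should equal `2 * (i + 1)`, and every value involved is exactly representable in `f32`. A small hedged check that could go before `Ok(())` in `run`, reusing the `o` vector from above:

```rust
// Sketch of a sanity check: x[i] = y[i] = i + 1, so o[i] must be 2 * (i + 1).
for (i, v) in o.iter().enumerate() {
    assert_eq!(*v, 2.0 * (i as f32 + 1.0), "mismatch at index {i}");
}
println!("all {} elements verified", SIZE);
```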