import functools
from dataclasses import dataclass
from threading import Lock
from typing import List, Dict, Any, Callable

import numpy as np
@dataclass
class KernelConfig:
    """Launch parameters for a CUDA-like kernel.

    Both dimension tuples are indexed as ``(x, y, z)``.
    """

    # Number of threads in one block along x, y and z.
    block_dim: tuple[int, int, int]
    # Number of blocks in the grid along x, y and z.
    grid_dim: tuple[int, int, int]
    # Shared-memory size requested for every block; defaults to none.
    shared_memory_size: int = 0
class ThreadIdx:
    """Thread index within a block.

    Attributes:
        x: Position along the first block dimension.
        y: Position along the second block dimension.
        z: Position along the third block dimension.
    """

    def __init__(self, x: int, y: int, z: int):
        self.x = x
        self.y = y
        self.z = z

    def __repr__(self) -> str:
        """Return a debug string such as ``ThreadIdx(x=1, y=2, z=3)``."""
        # type(self).__name__ keeps the output accurate for subclasses.
        return f"{type(self).__name__}(x={self.x!r}, y={self.y!r}, z={self.z!r})"
class BlockIdx:
    """Block index within the grid.

    Attributes:
        x: Position along the first grid dimension.
        y: Position along the second grid dimension.
        z: Position along the third grid dimension.
    """

    def __init__(self, x: int, y: int, z: int):
        self.x = x
        self.y = y
        self.z = z

    def __repr__(self) -> str:
        """Return a debug string such as ``BlockIdx(x=1, y=2, z=3)``."""
        # type(self).__name__ keeps the output accurate for subclasses.
        return f"{type(self).__name__}(x={self.x!r}, y={self.y!r}, z={self.z!r})"
class Warp:
    """Models a warp: a group of up to ``WARP_SIZE`` threads.

    Attributes:
        warp_id: Identifier given by the creator of the warp.
        threads: The thread indices belonging to this warp.
        active_mask: Bit mask with bit ``i`` set when ``threads[i]`` is
            active. Every thread passed to the constructor starts active.
    """

    WARP_SIZE = 32

    def __init__(self, warp_id: int, threads: List["ThreadIdx"]):
        self.warp_id = warp_id
        self.threads = threads
        # One bit per thread, all set: (1 << n) - 1 is n consecutive 1-bits.
        self.active_mask = (1 << len(threads)) - 1

    def synchronize(self):
        """Synchronize all threads in the warp (currently a no-op)."""
        pass

    def _votes(self, predicate):
        """Yield the boolean vote of each active thread, in lane order."""
        per_thread = callable(predicate)
        for lane, thread in enumerate(self.threads):
            # Skip lanes whose bit is cleared in the active mask.
            if not (self.active_mask >> lane) & 1:
                continue
            yield bool(predicate(thread)) if per_thread else bool(predicate)

    def vote_all(self, predicate: bool) -> bool:
        """Return True if the predicate holds for every active thread.

        Args:
            predicate: Either a single truth value applied to every active
                thread, or a callable taking a thread index and returning
                that thread's vote.

        Returns:
            True when no active thread votes false (vacuously True when
            there are no active threads).
        """
        return all(self._votes(predicate))

    def vote_any(self, predicate: bool) -> bool:
        """Return True if the predicate holds for at least one active thread.

        Args:
            predicate: Either a single truth value applied to every active
                thread, or a callable taking a thread index and returning
                that thread's vote.

        Returns:
            True when some active thread votes true (False when there are
            no active threads).
        """
        return any(self._votes(predicate))
class Block:
    """A thread block: grid position, dimensions, shared memory and warps."""

    def __init__(self, block_idx: BlockIdx, dim: tuple[int, int, int], shared_mem_size: int):
        self.block_idx = block_idx
        self.dim = dim
        self.shared_memory = SharedMemory(shared_mem_size)
        self.warps: List[Warp] = []
        self._create_warps()

    def _create_warps(self):
        """Enumerate the block's threads and group them into warps.

        Threads are numbered linearly with x varying fastest, then y, then
        z. Consecutive runs of ``Warp.WARP_SIZE`` threads form one warp; the
        final warp holds whatever remains.
        """
        width = self.dim[0]
        plane = width * self.dim[1]
        thread_count = plane * self.dim[2]

        lanes = []
        for linear in range(thread_count):
            # Peel off z first, then split the in-plane remainder into y, x.
            z, within_plane = divmod(linear, plane)
            y, x = divmod(within_plane, width)
            lanes.append(ThreadIdx(x, y, z))

        for start in range(0, thread_count, Warp.WARP_SIZE):
            members = lanes[start:start + Warp.WARP_SIZE]
            self.warps.append(Warp(len(self.warps), members))

    def synchronize(self):
        """Synchronize the block by synchronizing each of its warps in turn."""
        for member_warp in self.warps:
            member_warp.synchronize()
class SharedMemory:
    """Fixed-size byte buffer with lock-guarded reads and writes.

    Attributes:
        size: Capacity in bytes requested at construction.
        data: The backing ``bytearray``.
        lock: Lock held for the duration of every read and write.
    """

    def __init__(self, size_bytes: int):
        self.size = size_bytes
        self.data = bytearray(size_bytes)
        self.lock = Lock()

    def _check_range(self, offset: int, length: int) -> None:
        """Raise unless ``[offset, offset + length)`` lies inside the buffer."""
        if length < 0:
            raise ValueError(f"length must be non-negative, got {length}")
        capacity = len(self.data)
        if offset < 0 or offset + length > capacity:
            raise IndexError(
                f"shared memory access [{offset}, {offset + length}) "
                f"is outside the {capacity}-byte buffer"
            )

    def read(self, offset: int, size: int) -> bytearray:
        """Return a copy of ``size`` bytes starting at ``offset``.

        Raises:
            ValueError: If ``size`` is negative.
            IndexError: If the requested range is not fully inside the buffer.
        """
        with self.lock:
            self._check_range(offset, size)
            return self.data[offset:offset + size]

    def write(self, offset: int, data: bytearray):
        """Copy ``data`` into the buffer starting at ``offset``.

        Raises:
            IndexError: If the destination range is not fully inside the
                buffer.
        """
        with self.lock:
            # Without this check, slice assignment past the end would
            # silently grow the bytearray beyond its declared size.
            self._check_range(offset, len(data))
            self.data[offset:offset + len(data)] = data
class KernelFunction:
    """Callable wrapper around a kernel function.

    Attributes:
        func: The wrapped callable.
        shared_memory_size: Value stored by :meth:`configure`; 0 until then.
    """

    def __init__(self, func: Callable):
        # Copy __name__, __doc__, etc. from the wrapped function so the
        # wrapper is still identifiable in logs and tracebacks. This runs
        # first so the attributes assigned below win on any name clash.
        functools.update_wrapper(self, func)
        self.func = func
        self.shared_memory_size = 0

    def configure(self, shared_memory_size: int = 0):
        """Store kernel properties and return ``self`` so calls can be chained.

        Args:
            shared_memory_size: Shared-memory size to record for this kernel.
        """
        self.shared_memory_size = shared_memory_size
        return self

    def __call__(self, *args, **kwargs):
        """Call the wrapped function with the given arguments and return its result."""
        return self.func(*args, **kwargs)
def launch_kernel(kernel_func: KernelFunction, config: KernelConfig, *args):
    """Launch a kernel with the specified configuration.

    Every block of the grid is created first; then the kernel is called once
    per thread as ``kernel_func(block, thread, *args)``, sequentially in the
    calling thread (block by block, warp by warp, thread by thread).

    Args:
        kernel_func: The kernel to run for each thread.
        config: Grid and block dimensions plus the shared-memory size.
        *args: Extra positional arguments forwarded to every kernel call.

    Returns:
        None. Kernel return values are discarded.
    """
    grid_x = config.grid_dim[0]
    blocks_per_plane = grid_x * config.grid_dim[1]
    total_blocks = blocks_per_plane * config.grid_dim[2]

    # The launch config takes precedence. When it requests no shared memory,
    # fall back to the size stored on the kernel via configure(), which was
    # previously ignored. getattr keeps plain callables working.
    shared_mem_size = config.shared_memory_size or getattr(
        kernel_func, "shared_memory_size", 0
    )

    blocks = []
    for linear in range(total_blocks):
        # Linear block number -> (bx, by, bz) with x varying fastest.
        bz, within_plane = divmod(linear, blocks_per_plane)
        by, bx = divmod(within_plane, grid_x)
        blocks.append(Block(BlockIdx(bx, by, bz), config.block_dim, shared_mem_size))

    for block in blocks:
        for warp in block.warps:
            for thread in warp.threads:
                kernel_func(block, thread, *args)
def kernel(func: Callable) -> KernelFunction:
    """Decorator that turns ``func`` into a ``KernelFunction`` instance."""
    wrapped = KernelFunction(func)
    return wrapped