| | """
|
| | Command processor for handling GPU commands including thread management.
|
| | """
|
| | from typing import Dict, Any, List
|
| | from threading import Lock
|
| | import time
|
| |
|
class CommandProcessor:
    """Queue GPU commands and execute them in the order they were added.

    Each queued command is a dict with the keys ``type``, ``params`` and
    ``timestamp``. ``submit_commands`` drains the queue and hands every
    command to the matching private handler. Handlers never raise: a failure
    is reported as ``{"status": "error", "message": ...}`` in the result list.
    """

    def __init__(self, hal, memory_manager):
        """Store the collaborators and create an empty, lock-protected queue.

        Args:
            hal: Object providing ``block_barrier``, ``core_barrier``,
                ``global_barrier`` and ``matmul``; the handlers call these.
            memory_manager: Kept on the instance; no method of this class
                reads it.
        """
        self.hal = hal
        self.memory_manager = memory_manager
        self.command_queue = []
        # Guards every read/write of command_queue.
        self.queue_lock = Lock()
        # Keeps batches from different threads from running interleaved, now
        # that execution no longer happens while queue_lock is held.
        self._submit_lock = Lock()

    def add_command(self, command_type: str, **kwargs) -> None:
        """Append a command to the queue.

        Args:
            command_type: Name used by ``submit_commands`` to pick a handler.
            **kwargs: Stored as the command's ``params`` dict.
        """
        command = {
            "type": command_type,
            "params": kwargs,
            "timestamp": time.time_ns(),
        }
        with self.queue_lock:
            self.command_queue.append(command)

    def clear_commands(self) -> None:
        """Discard every command that has not been submitted yet."""
        with self.queue_lock:
            self.command_queue.clear()

    def submit_commands(self, chip_id: int = 0) -> List[Any]:
        """Execute all queued commands and return their results in order.

        The pending batch is copied out and the queue emptied while holding
        ``queue_lock``; the commands then run *without* that lock.
        ``queue_lock`` is a plain non-reentrant ``Lock``, so running handlers
        under it would deadlock as soon as a kernel function or HAL callback
        called ``add_command`` or ``clear_commands`` on this processor.
        Commands queued while a batch runs are picked up by the next call.

        Batches are still executed one at a time (``_submit_lock``), so calling
        ``submit_commands`` from inside a running command is not supported.

        Args:
            chip_id: Accepted for interface compatibility; not read here.
                Each command carries its own ``chip_id`` in its params.

        Returns:
            One result per command. Unknown command types produce an error
            dict; every other entry is whatever the handler returned.
        """
        handlers = {
            "execute_kernel": self._execute_kernel_command,
            "block_barrier": self._handle_block_barrier,
            "core_barrier": self._handle_core_barrier,
            "matmul": self._handle_matmul,
            "global_barrier": self._handle_global_barrier,
        }

        with self._submit_lock:
            with self.queue_lock:
                # Copy, then clear in place so the list object that callers
                # may hold a reference to stays the same.
                pending = list(self.command_queue)
                self.command_queue.clear()

            results = []
            for cmd in pending:
                handler = handlers.get(cmd["type"])
                if handler is None:
                    result = {
                        "status": "error",
                        "message": f"Unknown command type: {cmd['type']}",
                    }
                else:
                    result = handler(cmd["params"])
                results.append(result)
            return results

    def _execute_kernel_command(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Run a kernel function once per thread of every block, sequentially.

        Args:
            params: Must contain ``chip_id``, ``sm_id``, ``core_id``,
                ``thread_block_config`` (with 3-element ``grid_dim`` and
                ``block_dim`` plus ``shared_memory_size``) and ``kernel_func``.
                Optional ``args`` / ``kwargs`` are forwarded to the kernel.

        Returns:
            On success a dict with ``status``, ``blocks_executed``,
            ``total_threads`` and ``results`` (one list of per-thread entries
            per block). A kernel call that raises is recorded as an
            ``'error'`` entry for that thread and does not stop the run.
            Any other failure yields an error dict with a ``message``.
        """
        try:
            # These ids are not used for execution, but the command is
            # rejected (KeyError -> error result) when one is missing.
            for required in ("chip_id", "sm_id", "core_id"):
                if required not in params:
                    raise KeyError(required)

            thread_config = params["thread_block_config"]
            kernel_func = params["kernel_func"]
            args = params.get("args", [])
            kwargs = params.get("kwargs", {})

            grid_dim = thread_config["grid_dim"]
            block_dim = thread_config["block_dim"]
            blocks_per_grid = grid_dim[0] * grid_dim[1] * grid_dim[2]
            threads_per_block = block_dim[0] * block_dim[1] * block_dim[2]

            blocks = [
                {
                    "id": block_idx,
                    "threads": threads_per_block,
                    "shared_memory_size": thread_config["shared_memory_size"],
                    "results": [],
                }
                for block_idx in range(blocks_per_grid)
            ]

            for block in blocks:
                block_id = block["id"]
                outcomes = block["results"]
                first_thread = block_id * threads_per_block
                for thread_id in range(first_thread, first_thread + threads_per_block):
                    try:
                        # Positional args bind first, then the two ids by
                        # keyword, so a kernel taking extra positionals must
                        # declare them before thread_id / block_id.
                        result = kernel_func(
                            *args,
                            thread_id=thread_id,
                            block_id=block_id,
                            **kwargs,
                        )
                    except Exception as e:
                        # Deliberate best effort: one failing thread must not
                        # abort the remaining threads of the launch.
                        outcomes.append({
                            "thread_id": thread_id,
                            "error": str(e),
                            "status": "error",
                        })
                    else:
                        outcomes.append({
                            "thread_id": thread_id,
                            "result": result,
                            "status": "success",
                        })

            return {
                "status": "success",
                "blocks_executed": len(blocks),
                "total_threads": blocks_per_grid * threads_per_block,
                "results": [block["results"] for block in blocks],
            }

        except Exception as e:
            return {
                "status": "error",
                "message": f"Kernel execution failed: {str(e)}",
            }

    def _handle_block_barrier(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Call ``hal.block_barrier`` for the block named in ``params``.

        Args:
            params: Must contain ``chip_id``, ``sm_id``, ``core_id`` and
                ``block_id``.

        Returns:
            A success dict, or an error dict if a key is missing or the HAL
            call raises.
        """
        try:
            chip_id = params["chip_id"]
            sm_id = params["sm_id"]
            core_id = params["core_id"]
            block_id = params["block_id"]

            self.hal.block_barrier(chip_id, sm_id, core_id, block_id)

            return {
                "status": "success",
                "message": f"Block barrier completed for block {block_id}",
            }
        except Exception as e:
            return {
                "status": "error",
                "message": f"Block barrier failed: {str(e)}",
            }

    def _handle_core_barrier(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Call ``hal.core_barrier`` for the core named in ``params``.

        Args:
            params: Must contain ``chip_id``, ``sm_id`` and ``core_id``.

        Returns:
            A success dict, or an error dict if a key is missing or the HAL
            call raises.
        """
        try:
            chip_id = params["chip_id"]
            sm_id = params["sm_id"]
            core_id = params["core_id"]

            self.hal.core_barrier(chip_id, sm_id, core_id)

            return {
                "status": "success",
                "message": f"Core barrier completed for core {core_id}",
            }
        except Exception as e:
            return {
                "status": "error",
                "message": f"Core barrier failed: {str(e)}",
            }

    def _handle_matmul(self, params: Dict[str, Any]) -> Any:
        """Forward a matrix multiplication to ``hal.matmul``.

        Args:
            params: Must contain ``chip_id``, ``sm_id``, ``A`` and ``B``.

        Returns:
            Whatever ``hal.matmul`` returns, unchanged; an error dict if a
            key is missing or the HAL call raises.
        """
        try:
            return self.hal.matmul(
                params["chip_id"],
                params["sm_id"],
                params["A"],
                params["B"],
            )
        except Exception as e:
            return {
                "status": "error",
                "message": f"Matrix multiplication failed: {str(e)}",
            }

    def _handle_global_barrier(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Call ``hal.global_barrier`` for the chip named in ``params``.

        Args:
            params: Must contain ``chip_id``.

        Returns:
            A success dict, or an error dict if the key is missing or the
            HAL call raises.
        """
        try:
            chip_id = params["chip_id"]

            self.hal.global_barrier(chip_id)

            return {
                "status": "success",
                "message": f"Global barrier completed for chip {chip_id}",
            }
        except Exception as e:
            return {
                "status": "error",
                "message": f"Global barrier failed: {str(e)}",
            }
|
| |
|