操作分发器¶
操作分发器是Genesis v2.0的核心组件,负责将张量操作路由到适当的后端实现。
📋 概述¶
分发器提供:
- 集中的操作路由
- 自动后端选择
- 操作注册和管理
- 性能优化机会
🏗️ 架构¶
graph TB
subgraph "分发器组件"
A[OperationDispatcher] --> B[操作注册表]
A --> C[设备推断]
A --> D[后端选择器]
A --> E[执行引擎]
end
subgraph "操作流程"
F[用户调用] --> G[分发器]
G --> H[设备检测]
H --> I[选择实现]
I --> J[执行操作]
J --> K[返回结果]
end
style A fill:#e1f5fe
style G fill:#f3e5f5
🎯 核心功能¶
操作分发器类¶
Python
class OperationDispatcher:
    """Central operation dispatch system.

    Holds a registry mapping operation names to per-device backend
    implementations and routes each call to the appropriate one.
    """

    def __init__(self):
        # name -> {device_key: implementation}
        self._operations = {}
        # name -> extra registration info (filled in by helpers/subclasses)
        self._metadata = {}
        # generic cache storage used by caching subclasses
        self._cache = {}

    def register(self, name, implementations):
        """Register a new operation under *name*."""
        self._operations[name] = implementations

    def dispatch(self, op_name, *args, **kwargs):
        """Route *op_name* to a backend implementation and execute it.

        Raises:
            ValueError: if *op_name* was never registered.
        """
        # Guard clause: unknown operations are rejected up front.
        if op_name not in self._operations:
            raise ValueError(f"未知操作:{op_name}")
        # Pick the target device from the tensor arguments, then resolve
        # the concrete backend implementation for that device.
        device = self._infer_device(args)
        impl = self._select_implementation(op_name, device)
        return impl(*args, **kwargs)
设备推断¶
Python
def _infer_device(self, args):
    """Infer the target device from the positional arguments.

    Collects the ``device`` attribute of every tensor-like argument and
    applies a simple promotion rule when they disagree: CUDA wins.

    Returns:
        The shared device of the arguments, the promoted CUDA device on a
        mixed CPU/CUDA call, or the process-wide default when no argument
        carries a device.

    Raises:
        RuntimeError: if arguments live on conflicting non-CUDA devices.
    """
    devices = [arg.device for arg in args if hasattr(arg, 'device')]
    if not devices:
        # No tensor arguments at all — fall back to the default device.
        return genesis.get_default_device()
    # Compare devices by string form so distinct objects describing the
    # same device compare equal.
    unique_devices = set(str(d) for d in devices)
    if len(unique_devices) > 1:
        # Bug fix: the original tested `'cuda' in str(unique_devices)`,
        # i.e. a substring search over the set's repr (including quotes
        # and braces). Check each device string individually instead.
        if any('cuda' in d for d in unique_devices):
            return genesis.device('cuda')
        raise RuntimeError(f"设备冲突:{unique_devices}")
    return devices[0]
💡 操作注册¶
基本注册¶
Python
# Register a simple operation: one implementation per backend device.
dispatcher = OperationDispatcher()
dispatcher.register('add', {
'cpu': cpu_add_impl,
'cuda': cuda_add_impl
})
# Invoke the registered operation through the dispatcher; the backend
# is chosen from the devices of x and y.
result = dispatcher.dispatch('add', x, y)
带元数据的注册¶
Python
# Register an operation together with extra metadata the dispatcher or
# an optimizer can exploit: autograd support, memory profile, and the
# neighbouring ops it may be fused with.
dispatcher.register_with_metadata('matmul', {
'implementations': {
'cpu': cpu_matmul,
'cuda': cuda_matmul
},
'supports_autograd': True,
'memory_intensive': True,
'fusion_candidates': ['add', 'relu']
})
动态注册¶
Python
def register_dynamic_operation(name, generator):
    """Register an operation whose implementation is generated per call.

    *generator* receives the call's ``(args, kwargs)`` and must return a
    callable implementation, which is then invoked with those arguments.
    """
    def dynamic_dispatcher(*args, **kwargs):
        # Build the concrete implementation lazily from the inputs.
        impl = generator(args, kwargs)
        return impl(*args, **kwargs)

    # The same generated wrapper serves both backends.
    dispatcher.register(name, {'cpu': dynamic_dispatcher,
                               'cuda': dynamic_dispatcher})
🚀 优化策略¶
操作缓存¶
Python
class CachedDispatcher(OperationDispatcher):
    """Dispatcher that memoizes operation results.

    NOTE(review): the cache key is built from the operation name,
    argument shapes/dtypes and keyword arguments — not from tensor
    *values* — so this is only safe for operations whose result depends
    solely on input metadata (e.g. shape inference). Confirm intended
    usage before relying on it for value-producing ops.
    """

    def dispatch(self, op_name, *args, **kwargs):
        """Return a cached result when available, else compute and cache."""
        cache_key = self._generate_cache_key(op_name, args, kwargs)
        if cache_key in self._cache:
            return self._cache[cache_key]
        result = super().dispatch(op_name, *args, **kwargs)
        self._cache[cache_key] = result
        return result

    def _generate_cache_key(self, op_name, args, kwargs=None):
        """Build a hashable cache key for an operation call.

        Bug fix: the original key ignored ``kwargs`` entirely, so calls
        differing only in keyword arguments shared one cache entry. The
        new parameter defaults to ``None`` so old callers still work.
        Assumes kwarg values are hashable — TODO confirm for all ops.
        """
        shapes = tuple(arg.shape for arg in args if hasattr(arg, 'shape'))
        dtypes = tuple(arg.dtype for arg in args if hasattr(arg, 'dtype'))
        kw = tuple(sorted(kwargs.items())) if kwargs else ()
        return (op_name, shapes, dtypes, kw)
操作融合¶
Python
class FusionDispatcher(OperationDispatcher):
    """Dispatcher that can replace operation sequences with fused kernels."""

    def __init__(self):
        super().__init__()
        # Registered {'pattern', 'implementation'} entries, checked in order.
        self._fusion_patterns = []

    def register_fusion_pattern(self, pattern, fused_impl):
        """Register a fusion pattern and the implementation that replaces it."""
        entry = {'pattern': pattern, 'implementation': fused_impl}
        self._fusion_patterns.append(entry)

    def dispatch_sequence(self, operations):
        """Dispatch an operation sequence, fusing it when a pattern matches."""
        # First matching pattern wins and handles the whole sequence.
        for fusion in self._fusion_patterns:
            if self._matches_pattern(operations, fusion['pattern']):
                return fusion['implementation'](*operations)
        # No fusion opportunity: execute the operations one by one.
        return [self.dispatch(op.name, *op.args) for op in operations]
批量分发¶
Python
def batch_dispatch(self, operations):
    """Dispatch a batch of operations grouped by their target device.

    GPU groups are launched inside a dedicated CUDA stream so they can
    run asynchronously; CPU groups execute sequentially. Returns a dict
    mapping each operation object to its result.
    """
    # Bucket the operations by the device inferred from their arguments.
    device_groups = {}
    for op in operations:
        device = self._infer_device(op.args)
        device_groups.setdefault(device, []).append(op)

    results = {}
    for device, ops in device_groups.items():
        if device.is_cuda:
            # CUDA work proceeds asynchronously on its own stream.
            stream = genesis.cuda.Stream()
            with genesis.cuda.stream(stream):
                for op in ops:
                    results[op] = self.dispatch(op.name, *op.args)
        else:
            # CPU work executes in order on the host.
            for op in ops:
                results[op] = self.dispatch(op.name, *op.args)
    return results
🔧 配置选项¶
全局配置¶
Python
# Configure global dispatcher behaviour (fusion, cache size, profiling,
# strict device-consistency checks).
genesis.ops.dispatcher.set_config({
'enable_fusion': True,
'cache_size': 1000,
'profile_operations': False,
'strict_device_checking': True
})
操作特定配置¶
Python
# Tune settings for one specific operation.
genesis.ops.dispatcher.configure_operation('matmul', {
'use_cublas': True,
'transpose_threshold': 1024,
'block_size': 256
})
📊 性能监控¶
操作统计¶
Python
class ProfilingDispatcher(OperationDispatcher):
    """Dispatcher that records per-operation timing statistics."""

    def __init__(self):
        super().__init__()
        # op name -> {'count', 'total_time', 'max_time', 'min_time'}
        self._stats = {}

    def dispatch(self, op_name, *args, **kwargs):
        """Dispatch the operation while timing its execution."""
        start_time = time.perf_counter()
        result = super().dispatch(op_name, *args, **kwargs)
        elapsed = time.perf_counter() - start_time

        # Create the stats record lazily on first use of this op.
        if op_name not in self._stats:
            self._stats[op_name] = {
                'count': 0,
                'total_time': 0,
                'max_time': 0,
                'min_time': float('inf'),
            }

        entry = self._stats[op_name]
        entry['count'] += 1
        entry['total_time'] += elapsed
        entry['max_time'] = max(entry['max_time'], elapsed)
        entry['min_time'] = min(entry['min_time'], elapsed)
        return result

    def print_stats(self):
        """Print the collected timing statistics, one block per operation."""
        for op_name, stats in self._stats.items():
            avg_time = stats['total_time'] / stats['count']
            print(f"{op_name}:")
            print(f" 调用次数:{stats['count']}")
            print(f" 平均时间:{avg_time*1000:.3f} ms")
            print(f" 最大时间:{stats['max_time']*1000:.3f} ms")
            print(f" 最小时间:{stats['min_time']*1000:.3f} ms")
瓶颈检测¶
Python
def detect_bottlenecks(self):
    """Scan the collected statistics for likely performance problems.

    Flags operations that are slow on average (> 100 ms per call) and
    operations invoked very frequently (> 1000 calls). Returns a list
    of finding dicts, each carrying a human-readable suggestion.
    """
    findings = []
    for op_name, stats in self._stats.items():
        avg_time = stats['total_time'] / stats['count']
        # Slow on average: a candidate for optimization or fusion.
        if avg_time > 0.1:  # 100 ms threshold
            findings.append({
                'operation': op_name,
                'avg_time': avg_time,
                'suggestion': '考虑优化或融合'
            })
        # Called extremely often: a result cache may pay off.
        if stats['count'] > 1000:
            findings.append({
                'operation': op_name,
                'count': stats['count'],
                'suggestion': '考虑缓存结果'
            })
    return findings
🔍 调试功能¶
操作日志¶
Python
class DebugDispatcher(OperationDispatcher):
    """Dispatcher that logs the inputs and output of every dispatch."""

    def dispatch(self, op_name, *args, **kwargs):
        """Dispatch the operation, printing tensor metadata before and after."""
        print(f"[DISPATCH] 操作:{op_name}")
        # Log shape/dtype for each tensor-like positional argument.
        for i, arg in enumerate(args):
            if hasattr(arg, 'shape'):
                print(f" 参数{i}:shape={arg.shape}, dtype={arg.dtype}")
        result = super().dispatch(op_name, *args, **kwargs)
        # Log the result's metadata when it is tensor-like.
        if hasattr(result, 'shape'):
            print(f" 结果:shape={result.shape}, dtype={result.dtype}")
        return result
验证模式¶
Python
def enable_validation_mode(self):
    """Monkey-patch ``dispatch`` so every call validates inputs and output.

    The unwrapped dispatcher is kept in ``self._original_dispatch``; the
    wrapper is installed as an instance attribute so it shadows the
    class-level method.
    """
    self._validation_enabled = True

    def _checked_dispatch(op_name, *args, **kwargs):
        # Validate inputs, run the saved original dispatch, validate output.
        self._validate_inputs(op_name, args)
        result = self._original_dispatch(op_name, *args, **kwargs)
        self._validate_output(op_name, result)
        return result

    self._original_dispatch = self.dispatch
    self.dispatch = _checked_dispatch