跳转至

函数系统

Genesis的函数系统为自动微分提供基础,定义了操作在前向传播中如何执行以及在反向传播中如何计算梯度。

📋 概述

函数系统围绕Function基类构建,封装了: - 前向计算逻辑 - 反向梯度计算 - 用于存储中间值的上下文管理 - 与自动微分引擎的集成

🏗️ 架构

graph TB
    subgraph "函数系统"
        A[Function基类] --> B[apply()方法]
        A --> C[forward()方法]
        A --> D[backward()方法]
        E[Context] --> F[save_for_backward()]
        E --> G[saved_tensors]
    end

    subgraph "自动微分集成"
        B --> H[计算图]
        H --> I[梯度流]
        I --> J[反向传播]
    end

    subgraph "内置函数"
        K[AddFunction] --> A
        L[MulFunction] --> A
        M[MatMulFunction] --> A
        N[ReluFunction] --> A
    end

    style A fill:#e1f5fe
    style E fill:#f3e5f5
    style H fill:#e8f5e8

🎯 核心概念

Function基类

Function类为所有操作提供接口:

Python
class Function:
    """所有自动微分函数的基类。"""

    @staticmethod
    def apply(*args):
        """应用具有自动微分支持的函数。"""
        ctx = Context()

        # 前向传播
        result = cls.forward(ctx, *args)

        # 如果任何输入需要梯度,设置反向传播
        if any(tensor.requires_grad for tensor in args if isinstance(tensor, Tensor)):
            result.set_creator(ctx, cls.backward)

        return result

    @staticmethod
    def forward(ctx, *args):
        """计算前向传播。必须由子类实现。"""
        raise NotImplementedError

    @staticmethod
    def backward(ctx, *grad_outputs):
        """计算反向传播。必须由子类实现。"""
        raise NotImplementedError

上下文管理

Context类管理反向计算所需的信息:

Python
class Context:
    """用于存储反向传播期间所需信息的上下文。"""

    def __init__(self):
        self.saved_tensors = []
        self.saved_variables = {}

    def save_for_backward(self, *tensors):
        """保存张量以供反向传播使用。"""
        self.saved_tensors.extend(tensors)

    def save_variable(self, name, value):
        """保存变量以供反向传播使用。"""
        self.saved_variables[name] = value

💻 实现示例

基本算术函数

Python
class AddFunction(Function):
    """支持梯度的加法函数。"""

    @staticmethod
    def forward(ctx, a, b):
        """前向传播:计算a + b。"""
        # 加法不需要保存输入
        return genesis.ops.add(a, b)

    @staticmethod
    def backward(ctx, grad_output):
        """反向传播:梯度不变流动。"""
        return grad_output, grad_output

# 使用
add = AddFunction.apply
c = add(a, b)  # 等价于支持自动微分的 a + b

矩阵乘法函数

Python
class MatMulFunction(Function):
    """支持梯度的矩阵乘法。"""

    @staticmethod
    def forward(ctx, a, b):
        """前向传播:计算 a @ b。"""
        ctx.save_for_backward(a, b)
        return genesis.ops.matmul(a, b)

    @staticmethod
    def backward(ctx, grad_output):
        """反向传播:使用链式法则计算梯度。"""
        a, b = ctx.saved_tensors

        grad_a = genesis.ops.matmul(grad_output, b.transpose(-2, -1))
        grad_b = genesis.ops.matmul(a.transpose(-2, -1), grad_output)

        return grad_a, grad_b

# 使用
matmul = MatMulFunction.apply
c = matmul(a, b)  # 等价于支持自动微分的 a @ b

带上下文的激活函数

Python
class ReluFunction(Function):
    """支持梯度的ReLU激活。"""

    @staticmethod
    def forward(ctx, input):
        """前向传播:计算 max(0, input)。"""
        output = genesis.ops.maximum(input, 0)
        ctx.save_for_backward(input)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """反向传播:负输入的梯度为0。"""
        input, = ctx.saved_tensors
        mask = input > 0
        return grad_output * mask

# 使用
relu = ReluFunction.apply
activated = relu(x)

🚀 高级特性

原地操作

Python
class AddInplaceFunction(Function):
    """原地加法函数。"""

    @staticmethod
    def forward(ctx, a, b):
        """前向传播:原地修改a。"""
        ctx.save_variable('original_a', a.clone())
        a.add_(b)
        return a

    @staticmethod
    def backward(ctx, grad_output):
        """原地操作的反向传播。"""
        return grad_output, grad_output

多输出函数

Python
class SplitFunction(Function):
    """返回多个输出的函数。"""

    @staticmethod
    def forward(ctx, input, split_sizes):
        """将输入张量分割成多个部分。"""
        ctx.save_variable('split_sizes', split_sizes)
        return genesis.ops.split(input, split_sizes)

    @staticmethod
    def backward(ctx, *grad_outputs):
        """从多个输出连接梯度。"""
        grad_input = genesis.ops.cat(grad_outputs, dim=0)
        return grad_input, None  # split_sizes没有梯度

自定义上下文变量

Python
class ScaleFunction(Function):
    """通过常数因子缩放张量。"""

    @staticmethod
    def forward(ctx, input, scale_factor):
        """通过常数因子缩放输入。"""
        ctx.save_variable('scale_factor', scale_factor)
        return input * scale_factor

    @staticmethod
    def backward(ctx, grad_output):
        """通过相同因子缩放梯度。"""
        scale_factor = ctx.saved_variables['scale_factor']
        return grad_output * scale_factor, None

🔧 与操作集成

向调度器注册函数

Python
# 向操作调度器注册函数
genesis.ops.register_function('add', AddFunction.apply)
genesis.ops.register_function('matmul', MatMulFunction.apply)
genesis.ops.register_function('relu', ReluFunction.apply)

# 现在操作自动使用注册的函数
x = genesis.tensor([1, 2, 3], requires_grad=True)
y = genesis.tensor([4, 5, 6], requires_grad=True)
z = x + y  # 自动使用AddFunction

自定义操作定义

Python
def custom_operation(input, param):
    """使用Function定义自定义操作。"""
    return CustomFunction.apply(input, param)

# 注册为操作
genesis.ops.register_operation('custom_op', custom_operation)

# 像任何其他操作一样使用
result = genesis.custom_op(tensor, param)

📊 性能考虑

内存效率

Python
class EfficientFunction(Function):
    """内存高效的函数实现。"""

    @staticmethod
    def forward(ctx, input):
        # 只保存反向所需的内容
        ctx.save_for_backward(input.detach())  # 分离以避免递归梯度

        # 高效计算结果
        result = efficient_computation(input)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # 高效计算梯度
        return efficient_gradient_computation(input, grad_output)

数值稳定性

Python
class StableFunction(Function):
    """数值稳定的函数实现。"""

    @staticmethod
    def forward(ctx, input):
        # 使用数值稳定的计算
        output = stable_computation(input)
        ctx.save_for_backward(input, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, output = ctx.saved_tensors
        # 使用稳定的梯度计算
        return stable_gradient(input, output, grad_output)

🔍 调试和测试

函数测试

Python
def test_function_gradients():
    """测试函数梯度计算。"""
    x = genesis.tensor([1.0, 2.0, 3.0], requires_grad=True)

    # 测试前向传播
    y = CustomFunction.apply(x)

    # 测试反向传播
    y.backward(genesis.tensor([1.0, 1.0, 1.0]))

    # 检查梯度
    assert x.grad is not None
    print(f"梯度:{x.grad}")

# 数值梯度检查
def numerical_gradient_check(func, input, eps=1e-5):
    """使用数值微分检查梯度。"""
    # 数值梯度检查的实现
    pass

调试上下文

Python
class DebugFunction(Function):
    """带调试信息的函数。"""

    @staticmethod
    def forward(ctx, input):
        print(f"前向:输入形状 = {input.shape}")
        ctx.save_for_backward(input)
        result = computation(input)
        print(f"前向:输出形状 = {result.shape}")
        return result

    @staticmethod
    def backward(ctx, grad_output):
        print(f"反向:grad_output形状 = {grad_output.shape}")
        input, = ctx.saved_tensors
        grad_input = gradient_computation(input, grad_output)
        print(f"反向:grad_input形状 = {grad_input.shape}")
        return grad_input

🔗 参见