OpenClaw: A Simplified Implementation

OpenClaw is an information extraction model built on natural language understanding. This post provides a concise PyTorch implementation that covers the core architecture and basic functionality.

import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer

class OpenClawSimplified(nn.Module):
    """
    OpenClaw简化版模型
    用于信息提取任务,结合自然语言查询和文档内容
    """
    def __init__(self, bert_model_name='bert-base-uncased', hidden_size=768, num_labels=3):
        """
        初始化模型
        参数:
            bert_model_name: BERT模型名称
            hidden_size: 隐藏层维度
            num_labels: 标签数量 (BIO标注)
        """
        super(OpenClawSimplified, self).__init__()
        # Shared BERT encoder for both documents and queries
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.tokenizer = BertTokenizer.from_pretrained(bert_model_name)
        # Query-to-document multi-head attention
        self.query_attention = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )
        # Conditional layer normalization (defined below)
        self.conditional_layer_norm = ConditionalLayerNorm(hidden_size)
        # Span boundary prediction heads
        self.span_start_predictor = nn.Linear(hidden_size, 1)
        self.span_end_predictor = nn.Linear(hidden_size, 1)
        # Sequence labeling head
        self.sequence_labeler = nn.Linear(hidden_size, num_labels)
        # Fusion layer
        self.fusion_layer = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1)
        )
    def forward(self, input_ids, attention_mask, query_input_ids, query_attention_mask):
        """
        前向传播
        参数:
            input_ids: 文档token IDs [batch_size, seq_len]
            attention_mask: 文档注意力掩码
            query_input_ids: 查询token IDs [batch_size, query_len]
            query_attention_mask: 查询注意力掩码
        返回:
            span_start_logits: 跨度开始位置分数
            span_end_logits: 跨度结束位置分数
            sequence_logits: 序列标注分数
        """
        # Encode the document
        document_outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        document_embeddings = document_outputs.last_hidden_state  # [batch_size, seq_len, hidden_size]
        # Encode the query
        query_outputs = self.bert(
            input_ids=query_input_ids,
            attention_mask=query_attention_mask
        )
        query_embeddings = query_outputs.last_hidden_state  # [batch_size, query_len, hidden_size]
        # Query-aware document representation:
        # document tokens attend over the query tokens
        query_aware_doc, _ = self.query_attention(
            query=document_embeddings,
            key=query_embeddings,
            value=query_embeddings,
            key_padding_mask=~query_attention_mask.bool() if query_attention_mask is not None else None
        )
        # Conditional layer normalization
        normalized_doc = self.conditional_layer_norm(document_embeddings, query_aware_doc)
        # Fuse the original document representation with the query-aware one
        fused_representation = self.fusion_layer(
            torch.cat([document_embeddings, normalized_doc], dim=-1)
        )
        # Span boundary prediction
        span_start_logits = self.span_start_predictor(fused_representation).squeeze(-1)
        span_end_logits = self.span_end_predictor(fused_representation).squeeze(-1)
        # Sequence labeling
        sequence_logits = self.sequence_labeler(fused_representation)
        return span_start_logits, span_end_logits, sequence_logits
    def extract_spans(self, input_ids, attention_mask, query_input_ids, query_attention_mask, threshold=0.5):
        """
        Extract text spans from the document.
        Args:
            threshold: probability threshold for start/end predictions
        """
        self.eval()
        with torch.no_grad():
            span_start_logits, span_end_logits, _ = self.forward(
                input_ids, attention_mask, query_input_ids, query_attention_mask
            )
        # Apply sigmoid to obtain per-position probabilities
        start_probs = torch.sigmoid(span_start_logits)
        end_probs = torch.sigmoid(span_end_logits)
        # Extract spans for each example in the batch
        batch_spans = []
        for i in range(start_probs.size(0)):
            spans = []
            start_positions = torch.where(start_probs[i] > threshold)[0].tolist()
            end_positions = torch.where(end_probs[i] > threshold)[0].tolist()
            # Greedily pair each start with the nearest end at or after it
            for start in start_positions:
                possible_ends = [end for end in end_positions if end >= start]
                if possible_ends:
                    spans.append((start, min(possible_ends)))
            batch_spans.append(spans)
        return batch_spans
class ConditionalLayerNorm(nn.Module):
    """
    条件层归一化
    根据条件信息调整归一化
    """
    def __init__(self, hidden_size, eps=1e-12):
        super(ConditionalLayerNorm, self).__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        # Learnable scale and shift parameters
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        # Condition transformation parameters
        self.condition_weight = nn.Linear(hidden_size, hidden_size)
        self.condition_bias = nn.Linear(hidden_size, hidden_size)
    def forward(self, x, condition):
        """
        参数:
            x: 输入张量 [batch_size, seq_len, hidden_size]
            condition: 条件张量 [batch_size, seq_len, hidden_size]
        """
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        # 标准归一化
        normalized = (x - mean) / (std + self.eps)
        # Compute condition-dependent scale and shift
        condition_scale = self.condition_weight(condition)
        condition_shift = self.condition_bias(condition)
        # Apply learned parameters, then conditional modulation
        output = normalized * self.weight + self.bias
        output = output * condition_scale + condition_shift
        return output
class OpenClawTrainer:
    """
    OpenClaw训练器
    """
    def __init__(self, model, learning_rate=2e-5):
        self.model = model
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
        # Loss functions: binary cross-entropy for span boundaries,
        # cross-entropy for BIO sequence labels
        self.span_loss_fn = nn.BCEWithLogitsLoss()
        self.sequence_loss_fn = nn.CrossEntropyLoss()
    def train_step(self, batch):
        """
        单步训练
        """
        self.model.train()
        self.optimizer.zero_grad()
        # Unpack the batch
        (input_ids, attention_mask, 
         query_input_ids, query_attention_mask,
         span_start_labels, span_end_labels, sequence_labels) = batch
        # Forward pass
        span_start_logits, span_end_logits, sequence_logits = self.model(
            input_ids, attention_mask,
            query_input_ids, query_attention_mask
        )
        # Span boundary losses
        span_start_loss = self.span_loss_fn(
            span_start_logits, span_start_labels.float()
        )
        span_end_loss = self.span_loss_fn(
            span_end_logits, span_end_labels.float()
        )
        # Sequence labeling loss
        batch_size, seq_len, num_labels = sequence_logits.shape
        sequence_loss = self.sequence_loss_fn(
            sequence_logits.view(-1, num_labels),
            sequence_labels.view(-1)
        )
        # Total loss
        total_loss = span_start_loss + span_end_loss + sequence_loss
        # Backward pass and parameter update
        total_loss.backward()
        self.optimizer.step()
        return {
            'total_loss': total_loss.item(),
            'span_start_loss': span_start_loss.item(),
            'span_end_loss': span_end_loss.item(),
            'sequence_loss': sequence_loss.item()
        }
    def evaluate(self, dataloader):
        """
        评估模型
        """
        self.model.eval()
        total_loss = 0
        total_samples = 0
        with torch.no_grad():
            for batch in dataloader:
                (input_ids, attention_mask, 
                 query_input_ids, query_attention_mask,
                 span_start_labels, span_end_labels, sequence_labels) = batch
                # Forward pass
                span_start_logits, span_end_logits, sequence_logits = self.model(
                    input_ids, attention_mask,
                    query_input_ids, query_attention_mask
                )
                # Compute losses
                span_start_loss = self.span_loss_fn(
                    span_start_logits, span_start_labels.float()
                )
                span_end_loss = self.span_loss_fn(
                    span_end_logits, span_end_labels.float()
                )
                batch_size, seq_len, num_labels = sequence_logits.shape
                sequence_loss = self.sequence_loss_fn(
                    sequence_logits.view(-1, num_labels),
                    sequence_labels.view(-1)
                )
                total_loss += (span_start_loss + span_end_loss + sequence_loss).item()
                total_samples += batch_size
        return total_loss / total_samples if total_samples > 0 else 0
def create_sample_batch():
    """
    创建示例批次数据
    """
    # Dummy data dimensions
    batch_size = 2
    seq_len = 10
    query_len = 5
    num_labels = 3  # BIO labels
    # Document token IDs and mask
    input_ids = torch.randint(100, 1000, (batch_size, seq_len))
    attention_mask = torch.ones((batch_size, seq_len))
    # Query token IDs and mask
    query_input_ids = torch.randint(100, 1000, (batch_size, query_len))
    query_attention_mask = torch.ones((batch_size, query_len))
    # Span boundary labels (binary per position)
    span_start_labels = torch.randint(0, 2, (batch_size, seq_len)).float()
    span_end_labels = torch.randint(0, 2, (batch_size, seq_len)).float()
    # Sequence labels
    sequence_labels = torch.randint(0, num_labels, (batch_size, seq_len))
    return (input_ids, attention_mask, 
            query_input_ids, query_attention_mask,
            span_start_labels, span_end_labels, sequence_labels)
# Usage example
if __name__ == "__main__":
    # Initialize the model
    model = OpenClawSimplified(num_labels=3)
    # Create the trainer
    trainer = OpenClawTrainer(model)
    # Create dummy data
    batch = create_sample_batch()
    # Run one training step
    losses = trainer.train_step(batch)
    print("Training losses:", losses)
    # Span extraction example
    (input_ids, attention_mask, 
     query_input_ids, query_attention_mask, _, _, _) = batch
    spans = model.extract_spans(
        input_ids, attention_mask,
        query_input_ids, query_attention_mask
    )
    print("提取的跨度:", spans)
    # Parameter counts
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"总参数: {total_params:,}")
    print(f"可训练参数: {trainable_params:,}")

Key Features

  1. Dual-encoder architecture: encodes both the document and the natural-language query (see the input-preparation sketch after this list)
  2. Query attention mechanism: multi-head attention lets the document representation attend to query-relevant information
  3. Conditional layer normalization: adjusts the document representation based on the query information
  4. Multi-task learning: performs span extraction and sequence labeling jointly
  5. Lightweight design: drops the complex modules of the full OpenClaw while keeping the core functionality
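
Since the document and the query are encoded separately, each must be tokenized on its own before calling the model. Below is a minimal input-preparation sketch; the sample document and query strings are illustrative placeholders, and it assumes the same 'bert-base-uncased' tokenizer that the model loads internally:

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = OpenClawSimplified(num_labels=3)

document = "OpenClaw Inc. was founded in 2020 in Berlin."  # illustrative text
query = "Where was the company founded?"                   # illustrative query

# Tokenize the document and the query separately; forward() expects
# an (input_ids, attention_mask) pair for each of them
doc_enc = tokenizer(document, return_tensors='pt',
                    padding='max_length', truncation=True, max_length=128)
query_enc = tokenizer(query, return_tensors='pt',
                      padding='max_length', truncation=True, max_length=32)

span_start_logits, span_end_logits, sequence_logits = model(
    doc_enc['input_ids'], doc_enc['attention_mask'],
    query_enc['input_ids'], query_enc['attention_mask']
)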

Usage

# 1. Initialize the model
model = OpenClawSimplified(num_labels=3)  # 3 = number of BIO labels
# 2. Prepare data
# Requires tokenized inputs for both the document and the query
# 3. Train the model
trainer = OpenClawTrainer(model)
for epoch in range(num_epochs):
    for batch in dataloader:
        losses = trainer.train_step(batch)
# 4. Extract information
spans = model.extract_spans(
    input_ids, attention_mask,
    query_input_ids, query_attention_mask,
    threshold=0.5
)
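
The spans returned by extract_spans are token-index pairs into the document, not text. A short decoding sketch, assuming the doc_enc and query_enc encodings from the input-preparation sketch above:

# 5. Decode token-level spans back to text
spans = model.extract_spans(
    doc_enc['input_ids'], doc_enc['attention_mask'],
    query_enc['input_ids'], query_enc['attention_mask'],
    threshold=0.5
)
for start, end in spans[0]:  # spans for the first example in the batch
    token_ids = doc_enc['input_ids'][0][start:end + 1]
    print(tokenizer.decode(token_ids, skip_special_tokens=True))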

This simplified version keeps the core ideas of the original OpenClaw while greatly reducing implementation complexity, making it suitable for learning, experimentation, and small-scale deployment.
