Open Webui限制用户Tokens函数 解决vllm生产环境Tokens上限问题

在使用 Open WebUI 时,若模型后端为 vLLM 而非 Ollama,由于模型存在 Token 上限,对话超出 Token 数量后会报错。在 Open WebUI 中使用以下函数,可解决此问题。

"""
title: Context Length Filter
description: Truncate chat context length with 'token limit' and 'max turns', showing status while limit exceeded. System message and multimodal messages excluded.
author: Kejun Luo
version: 0.5
"""

import tiktoken
from pydantic import BaseModel, Field
from typing import Optional, Callable, Any, Awaitable

class Filter:
    """Open WebUI inlet filter that truncates chat context.

    Two independent limits are applied in :meth:`inlet` before the request is
    forwarded to the model backend (e.g. vLLM, which rejects over-long prompts):

    1. ``max_turns`` — cap the number of conversation turns.
    2. ``token_limit`` — drop the oldest messages until the remainder fits.

    A status event is emitted through ``__event_emitter__`` whenever a limit
    is exceeded.
    """

    class Valves(BaseModel):
        # Admin-configurable settings exposed in the Open WebUI filter UI.
        priority: int = Field(default=0, description="Priority level")
        max_turns: int = Field(
            default=25,
            description="Number of conversation turns to retain. Set '0' for unlimited",
        )
        token_limit: int = Field(
            default=10000,
            description="Number of token limit to retain. Set '0' for unlimited",
        )

    class UserValves(BaseModel):
        # No per-user settings.
        pass

    def __init__(self):
        self.valves = self.Valves()
        # cl100k_base is an approximation of the backend model's tokenizer;
        # actual token counts may differ slightly, so keep token_limit below
        # the model's hard context limit.
        self.encoding = tiktoken.get_encoding("cl100k_base")

    async def inlet(
        self,
        body: dict,
        __event_emitter__: Callable[[Any], Awaitable[None]],
        __model__: Optional[dict] = None,
    ) -> dict:
        """Truncate ``body["messages"]`` to the configured turn/token limits.

        :param body: Request payload; ``body["messages"]`` is the chat history
            (list of ``{"role": ..., "content": ...}`` dicts).
        :param __event_emitter__: Awaitable callback used to show a status
            line in the UI when truncation happens.
        :param __model__: Unused; provided by Open WebUI.
        :return: The same ``body`` with ``body["messages"]`` truncated.
        """
        # Robustness: tolerate a body without "messages" instead of raising.
        messages = body.get("messages", [])
        chat_messages = messages[:]

        # 1. Truncate by turn count (a turn = one user + one assistant message).
        # NOTE(review): a leading system message is not explicitly preserved
        # here despite the module docstring's claim — confirm intended.
        if self.valves.max_turns > 0:
            current_turns = (len(chat_messages) - 1) // 2
            if current_turns > self.valves.max_turns:
                sent_msg_count = self.valves.max_turns * 2 + 1
                await self.show_exceeded_status(
                    __event_emitter__, self.valves.max_turns
                )
                chat_messages = chat_messages[-sent_msg_count:]

        # 2. Truncate by token count, keeping the newest messages.
        if self.valves.token_limit > 0:
            filter_messages = []
            current_toks = 0
            for msg in reversed(chat_messages):
                toks = self.count_tokens(msg)
                not_user = msg.get("role", "") != "user"
                # The first sent message must be a user message, so never
                # break on (i.e. drop down to) a user message.
                if (current_toks + toks > self.valves.token_limit) and not_user:
                    # Approximate turn count for the status line.
                    current_turns = len(filter_messages) // 2 + 1
                    await self.show_exceeded_status(__event_emitter__, current_turns)
                    break
                filter_messages.insert(0, msg)
                current_toks += toks
        else:
            filter_messages = chat_messages

        body["messages"] = filter_messages

        return body

    async def show_exceeded_status(
        self, __event_emitter__: Callable[[Any], Awaitable[None]], turn_count: int
    ) -> None:
        """Emit a UI status event saying the context was truncated.

        :param turn_count: Number of retained turns; reported as
            ``turn_count * 2 + 1`` messages (turns plus the latest user message).
        """
        count = turn_count * 2 + 1
        await __event_emitter__(
            {
                "type": "status",
                "data": {
                    "description": f"Context limit reached - keeping last {count} messages",
                    "done": True,
                },
            }
        )

    def count_tokens(self, msg: dict) -> int:
        """Count tokens in one message's ``content`` with the loaded encoding.

        Only ``"text"`` parts of multimodal (list) content are counted;
        images and other part types contribute 0 tokens.
        """
        content = msg.get("content", "")
        total_tokens = 0

        if isinstance(content, list):
            # Multimodal content: sum the text parts only.
            for item in content:
                if item.get("type") == "text":
                    text = item.get("text", "")
                    total_tokens += len(self.encoding.encode(text))
        elif isinstance(content, str):
            # Plain text content.
            total_tokens = len(self.encoding.encode(content))
        else:
            # Unexpected content type (e.g. None): count as zero.
            total_tokens = 0

        return total_tokens
© 版权声明
THE END
若本文对您有帮助,欢迎点赞打赏转发
您的支持将是作者更新最大的动力
点赞24打赏 分享
评论 抢沙发

请登录后发表评论

    暂无评论内容