因为公司内部提供的大模型接口与常规的不太一致,市面上常见的 MCP Client 无法直接对接,这里根据 MCP github 上的 MCP chat client example 代码
( https://github.com/modelcontextprotocol/python-sdk/tree/main/examples/clients/simple-chatbot ),改写了一下,主要是支持自定义的大模型 API。
另外利用 rich、prompt_toolkit 美化了控制台的界面,不再是日志格式,并添加了更多 MCP 特性及聊天器必备命令。
增加命令:清理对话历史、切换大模型、查看 tool 列表及明细、查看 prompts 列表及明细、使用 prompts
增加了对提示语意义不明、同时匹配到多个 tool 的处理:列出候选 tool 并编号,由用户选择要执行的那一个(该规则写在 system prompt 中)
实现了一条提示语触发多个 tool 连续调用时的处理
[1] 配置:servers_config.json,案例
{
"mcpServers": {
"sqlite": {
"command": "uvx",
"args": ["mcp-server-sqlite", "--db-path", "./test.db"],
"disabled": false
},
"puppeteer": {
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-puppeteer"],
"disabled": false
}
}
}
[2] 配置环境变量文件: .env,案例
OPEN_API_KEY="sk-proj-f********"
OPEN_PROXY="http://192.168.1.8:8080"
DEEPSEEK_API_KEY="sk-****"
DEEPSEEK_PROXY=
CUSTOM_OPEN_ID="xxxxxxxx"
CUSTOM_PROXY=
NO_PROXY="127.0.0.1,localhost"
依赖:python>=3.13
mcp==1.9.0
python-dateutil==2.9.0.post0
python-dotenv==1.1.0
python-multipart==0.0.20
pytz==2025.2
rich==14.0.0
shellingham==1.5.4
six==1.17.0
sniffio==1.3.1
sse-starlette==2.3.5
starlette==0.46.2
typer==0.15.4
typing-inspection==0.4.1
typing_extensions==4.13.2
tzdata==2025.2
urllib3==2.4.0
uvicorn==0.34.2
wcwidth==0.2.13
prompt_toolkit==3.0.51
代码参考
# coding:utf-8
import asyncio
import json
from rich.console import Console
from rich.panel import Panel
from rich.prompt import Prompt
from prompt_toolkit.completion import WordCompleter
from rich.box import Box,ROUNDED
from prompt_toolkit import PromptSession
from prompt_toolkit.input import create_input
from rich.markdown import Markdown
from rich.live import Live
from rich.text import Text
import os
import shutil
from contextlib import AsyncExitStack
from typing import Any
import hashlib
import uuid
import httpx
from dotenv import load_dotenv
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.client.sse import sse_client
from mcp.types import CallToolResult,TextContent,PromptArgument,GetPromptResult,PromptMessage,TextContent
class NoSlideBox(Box):
    """Rounded box style whose content rows have no left/right border.

    Only the top edge, the row separators and the bottom edge are drawn; the
    head, body and foot rows use blanks for every border position.
    """

    def __init__(self,):
        # rich's Box expects exactly eight rows of exactly four characters
        # (left, horizontal/fill, divider, right). The borderless rows must
        # therefore be four spaces wide; a shorter row makes Box.__init__
        # fail while unpacking the row into its four border characters.
        super().__init__(
            "╭─┬╮\n"
            "    \n"
            "├─┼┤\n"
            "    \n"
            "├─┼┤\n"
            "├─┼┤\n"
            "    \n"
            "╰─┴╯\n"
        )
# Console configuration: one rich Console per kind of output.
console = Console(color_system="auto")
# Writes to stderr, rendered bold red.
error_console = Console(stderr=True, style="bold red")
# Rendered dim blue.
server_console = Console(style="dim blue")
# Selectable LLM back-ends, keyed by the menu number shown by the "swm" command.
# Each entry carries the channel, model name, endpoint URL and provider name
# that are handed to the LLM client constructors.
client_models = {
    "1": {
        "ai_channel": "OpenAI",
        "ai_model": "gpt-3.5-turbo",
        "ai_api_url": "http://[自定义模型域名]/v1/chat/completion",
        "ai_provider": "[自定义模型平台名]",
    },
    "2": {
        "ai_channel": "OpenAI",
        "ai_model": "gpt-4",
        "ai_api_url": "https://api.openai.com/v1/chat/completions",
        "ai_provider": "OpenAI",
    },
    "3": {
        "ai_channel": "Deepseek",
        "ai_model": "deepseek-chat",
        "ai_api_url": "https://api.deepseek.com/v1/chat/completions",
        "ai_provider": "Deepseek",
    },
}
class Configuration:
    """Manages configuration and environment variables for the MCP client."""

    def __init__(self):
        """Load the .env file and read provider credentials and proxies."""
        self.load_env()
        # Backs the `llm_api_key` property; None when LLM_API_KEY is unset or empty.
        self.api_key = os.getenv("LLM_API_KEY") or None
        self.open_api_key = os.getenv("OPEN_API_KEY", "sk-proj-*****")
        # `or None` maps an empty value (e.g. `OPEN_PROXY=` in .env) to "no proxy".
        self.open_proxy = os.getenv("OPEN_PROXY") or None
        self.deepseek_api_key = os.getenv("DEEPSEEK_API_KEY", "*****")
        self.deepseek_proxy = os.getenv("DEEPSEEK_PROXY") or None
        self.custom_open_id = os.getenv("CUSTOM_OPEN_ID", "*****")
        self.custom_proxy = os.getenv("CUSTOM_PROXY") or None

    @staticmethod
    def load_env() -> None:
        """Load environment variables from .env file."""
        load_dotenv()

    @staticmethod
    def load_config(file_path: str) -> dict[str, Any]:
        """Load server configuration from JSON file.

        Args:
            file_path: Path to the JSON configuration file.

        Returns:
            Dict containing server configuration.

        Raises:
            FileNotFoundError: If configuration file doesn't exist.
            JSONDecodeError: If configuration file is invalid JSON.
        """
        # Read as UTF-8 explicitly so the result does not depend on the
        # platform's default locale encoding.
        with open(file_path, "r", encoding="utf-8") as f:
            return json.load(f)

    @property
    def llm_api_key(self) -> str:
        """Get the LLM API key.

        Returns:
            The API key as a string.

        Raises:
            ValueError: If the API key is not found in environment variables.
        """
        if not self.api_key:
            raise ValueError("LLM_API_KEY not found in environment variables")
        return self.api_key
# Module-level Configuration instance so the settings are shared across the whole program
config = Configuration()
class Tool:
    """An MCP tool together with the metadata used to describe it to the LLM."""

    def __init__(
        self, name: str, description: str, input_schema: dict[str, Any],server_name: str = None
    ) -> None:
        self.name: str = name
        self.description: str = description
        self.input_schema: dict[str, Any] = input_schema
        # Name of the server that exposes this tool (None when not given).
        self.server_name = server_name

    def format_for_llm(self) -> str:
        """Render the tool as a text block for the system prompt.

        Returns:
            A string with the tool name, its description and one
            "- name: description" line per argument; required arguments are
            suffixed with " (required)".
        """
        schema = self.input_schema
        argument_lines = [
            f"- {arg_name}: {arg_spec.get('description', 'No description')}"
            + (" (required)" if arg_name in schema.get("required", []) else "")
            for arg_name, arg_spec in schema.get("properties", {}).items()
        ]
        arguments_block = "\n".join(argument_lines)
        return (
            "\n"
            f"Tool: {self.name}\n"
            f"Description: {self.description}\n"
            "Arguments:\n"
            f"{arguments_block}\n"
        )
class MCPPrompt:
    """An MCP prompt template together with its display/formatting helpers."""

    def __init__(
        self, name: str, description: str, arguments: list[PromptArgument],server_name: str = None
    ) -> None:
        self.name: str = name
        self.description: str = description
        self.arguments: list[PromptArgument] = arguments
        # Name of the server that exposes this prompt (None when not given).
        self.server_name = server_name

    def get_prompt_dict(self) -> dict:
        """Describe the prompt as a plain dictionary.

        Returns:
            A dict with "ServerName" and "PromptName", plus "Description" and
            "Arguments" when those are non-empty. Each argument entry holds
            "Argument" and, when set, "Description" and "Required".
        """
        info = {"ServerName": self.server_name, "PromptName": self.name}
        if self.description:
            info["Description"] = self.description

        described_args = []
        for argument in self.arguments or []:
            entry = {"Argument": argument.name}
            if argument.description:
                entry["Description"] = argument.description
            if argument.required:
                entry["Required"] = True
            described_args.append(entry)

        if described_args:
            info["Arguments"] = described_args
        return info

    @property
    def format_for_rich(self) -> str:
        """Format prompt information for rich terminal."""
        pieces = [
            f"+ [bold white]{self.server_name}[/bold white] - [bold yellow]{self.name}[/bold yellow]"
        ]
        if self.arguments:
            arg_names = json.dumps([arg.name for arg in self.arguments], ensure_ascii=False)
            pieces.append(f"\n > [magenta]arguments[/magenta]: {arg_names}")
        if self.description:
            pieces.append(f"\n > [magenta]description[/magenta]: {self.description}")
        return "".join(pieces)
class Server:
    """Manages one MCP server connection plus its tool and prompt calls."""

    def __init__(self, name: str, config: dict[str, Any]) -> None:
        self.name: str = name
        self.config: dict[str, Any] = config
        self.stdio_context: Any | None = None
        self.session: ClientSession | None = None
        self._cleanup_lock: asyncio.Lock = asyncio.Lock()
        # Owns the transport and session contexts so cleanup() can close both.
        self.exit_stack: AsyncExitStack = AsyncExitStack()

    async def initialize(self) -> None:
        """Open the transport (stdio or SSE) and initialize the MCP session.

        A configured "command" selects the stdio transport; otherwise
        "sseUrl" selects the SSE transport.

        Raises:
            ValueError: If neither a usable command nor an sseUrl is configured.
            Exception: Any connection/initialization error, re-raised after
                it has been reported and the server has been cleaned up.
        """
        # "npx" is resolved through PATH; any other command is used as given.
        command = (
            shutil.which("npx")
            if self.config.get("command") == "npx"
            else self.config.get("command")
        )
        sse_url = self.config.get("sseUrl")  # SSE connection support
        if command:
            server_params = StdioServerParameters(
                command=command,
                # A config entry without "args" means "no arguments".
                args=self.config.get("args") or [],
                env={**os.environ, **self.config["env"]}
                if self.config.get("env")
                else None,
            )
            transport_cm = stdio_client(server_params)
        elif sse_url:
            transport_cm = sse_client(sse_url)
        else:
            raise ValueError("The command or sseUrl must be a valid string and cannot be None.")
        try:
            read, write = await self.exit_stack.enter_async_context(transport_cm)
            session = await self.exit_stack.enter_async_context(
                ClientSession(read, write)
            )
            await session.initialize()
            self.session = session
        except Exception as e:
            error_console.print(f"Error initializing server {self.name}: {e}")
            await self.cleanup()
            raise

    async def list_tools(self) -> list[Any]:
        """List available tools from the server.

        Returns:
            A list of Tool objects tagged with this server's name.

        Raises:
            RuntimeError: If the server is not initialized.
        """
        if not self.session:
            raise RuntimeError(f"Server {self.name} not initialized")
        tools_response = await self.session.list_tools()
        tools = []
        # Iterating the response yields (field_name, value) pairs; only the
        # "tools" field is of interest.
        for item in tools_response:
            if isinstance(item, tuple) and item[0] == "tools":
                tools.extend(
                    Tool(tool.name, tool.description, tool.inputSchema, server_name=self.name)
                    for tool in item[1]
                )
        return tools

    async def execute_tool(
        self,
        tool_name: str,
        arguments: dict[str, Any],
    ) -> Any:
        """Execute a tool once; failures are returned, not raised.

        Args:
            tool_name: Name of the tool to execute.
            arguments: Tool arguments.

        Returns:
            The tool execution result. On any failure (including an
            uninitialized server) a CallToolResult with isError=True that
            carries the error text.
        """
        try:
            if not self.session:
                raise RuntimeError(f"Server {self.name} not initialized")
            return await self.session.call_tool(tool_name, arguments)
        except Exception as e:
            return CallToolResult(
                content=[
                    TextContent(
                        type="text",
                        text=f"Error executing tool {self.name} - {tool_name}: {e}.",
                    )
                ],
                isError=True,
            )

    async def list_prompts(self) -> list[MCPPrompt]:
        """Get prompts from the server.

        Returns:
            A list of prompts; an empty list when the request fails (the
            error is reported on the error console).

        Raises:
            RuntimeError: If the server is not initialized.
        """
        if not self.session:
            raise RuntimeError(f"Server {self.name} not initialized")
        try:
            prompts_response = await self.session.list_prompts()
            prompts = []
            # Same (field_name, value) iteration as in list_tools.
            for item in prompts_response:
                if isinstance(item, tuple) and item[0] == "prompts":
                    prompts.extend(
                        MCPPrompt(
                            prompt.name,
                            prompt.description,
                            prompt.arguments,
                            server_name=self.name,
                        )
                        for prompt in item[1]
                    )
            return prompts
        except Exception as e:
            error_console.print(f"Error getting prompts from server {self.name}: {e}")
            return []

    async def get_prompt(self, prompt_name: str, arguments: dict[str, Any]) -> GetPromptResult | None:
        """Fetch a rendered prompt from the server.

        Args:
            prompt_name: Name of the prompt to call.
            arguments: Prompt arguments.

        Returns:
            The prompt result, or None when the request fails or the server
            is not initialized (the error is reported on the error console).
        """
        try:
            if not self.session:
                raise RuntimeError(f"Server {self.name} not initialized")
            return await self.session.get_prompt(prompt_name, arguments)
        except Exception as e:
            # Report the reason instead of dropping it; callers only see None.
            error_console.print(
                f"Error getting prompt {prompt_name} from server {self.name}: {e}"
            )
            return None

    async def cleanup(self) -> None:
        """Clean up server resources."""
        async with self._cleanup_lock:
            try:
                await self.exit_stack.aclose()
                self.session = None
                self.stdio_context = None
            except Exception as e:
                error_console.print(f"Error during cleanup of server {self.name}: {e}")
class LLMClient:
"""Manages communication with the LLM provider."""
def __init__(self, api_key: str,ai_channel: str = "OpenAI",ai_model: str = "gpt-4",
ai_api_url: str = "https://api.openai.com/v1/chat/completions",
ai_provider: str = "OpenAI",
http_proxy: str = None) -> None:
self.api_key: str = api_key
self.ai_channel = ai_channel
self.ai_model = ai_model
self.ai_api_url = ai_api_url
self.ai_provider = ai_provider
self.http_proxy = http_proxy
async def get_response(self, messages: list[dict[str, str]]) -> tuple[str,dict|None]:
"""Get a response from the LLM.
Args:
messages: A list of message dictionaries.
Returns:
The LLM's response as a string.
Raises:
httpx.RequestError: If the request to the LLM fails.
"""
url = self.ai_api_url
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
}
payload = {
"messages": messages,
"model": self.ai_model,
"temperature": 0.7,
"max_tokens": 4096,
"top_p": 1,
"stream": False,
"stop": None,
}
try:
async with httpx.AsyncClient(proxy=self.http_proxy) as client:
response = await client.post(url, headers=headers, json=payload,timeout=60)
response.raise_for_status()
data = response.json()
# console.print(f"LLM response: {json.dumps(data, indent=2, ensure_ascii=False)}")
usage = None
if data.get("usage"):
usage={}
usage["prompt_tokens"] = data["usage"].get("prompt_tokens")
usage["completion_tokens"] = data["usage"].get("completion_tokens")
usage["total_tokens"] = data["usage"].get("total_tokens")
return data["choices"][0]["message"]["content"],usage
except httpx.HTTPError as e:
error_message = f"Error getting LLM response: {str(e)}"
if isinstance(e, httpx.HTTPStatusError):
error_message = f"Error getting LLM response: {e.response.status_code} | {e.response.text}"
return (
f"I encountered an error: {error_message}. "
"Please try again or rephrase your request."
),None
class LLMClient2:
# 此处为公司内部的AI助手平台API调用客户端的实现
# 按照实际编写,此处仅供参考
def __init__(self,custom_open_id,ai_channel: str = "OpenAI",
ai_model: str = "gpt-3.5-turbo",
ai_provider: str = "CUSTOM",
ai_api_url: str = "http://[自定义模型域名]/v1/chat/completion",
http_proxy: str = None
) -> None:
self.custom_open_id= custom_open_id
self.ai_channel = ai_channel
self.ai_api_url = ai_api_url
self.ai_provider = ai_provider
self.ai_model = ai_model
self.http_proxy = http_proxy
def md5Sign(self,plainText):
md5Obj = hashlib.md5()
md5Obj.update((self.fosp_encode_key+plainText).encode('utf-8'))
return md5Obj.hexdigest()
async def get_response(self, messages: list[dict[str, str]]) -> tuple[str,dict|None]:
messagesStr=json.dumps(messages,indent=2)
headers={ # 请求openai需要的请求头
"Content-Type": "application/json",
"Open-ID": self.custom_open_id,
"Service-Code":"aigc-direct",
"Service-Type": '3'
}
payload = {
'channel': self.ai_channel,
'model': self.ai_model,
"prompt": messagesStr,
"promptEnc": self.md5Sign(messagesStr),
"maxTokens":16384,
"temperature":0.7,
"session":str(uuid.uuid4()),
}
try:
async with httpx.AsyncClient(proxy=self.http_proxy) as client:
response = await client.post( self.ai_api_url, headers=headers, json=payload,timeout=60)
response.raise_for_status()
data = response.json()
if not data.get('success') or data.get("code")!="AS0000":
error_console.print(f"Error getting LLM response: {data.get("message")}")
return data.get("message"),None
usage = None
if data.get("data",{}).get("usage"):
usage={}
usage["prompt_tokens"] = data["data"]["usage"].get("promptTokens")
usage["completion_tokens"] = data["data"]["usage"].get("completionTokens")
usage["total_tokens"] = data["data"]["usage"].get("totalTokens")
return data["data"]["text"], usage
except httpx.RequestError as e:
error_message = f"Error getting LLM response: {str(e)}"
if isinstance(e, httpx.HTTPStatusError):
error_message = f"Error getting LLM response: {e.response.status_code} | {e.response.text}"
return (
f"I encountered an error: {error_message}. "
"Please try again or rephrase your request."
),None
class ChatSession:
    """Orchestrates the interaction between user, LLM, and tools."""

    def __init__(self, servers: list[Server], llm_client: LLMClient | LLMClient2) -> None:
        self.servers: list[Server] = servers
        self.llm_client: LLMClient | LLMClient2 = llm_client
        # Token usage from the latest LLM reply that reported one; None if unknown.
        self.usage: dict | None = None

    def showSysInfo(self,msg:str|Markdown,title:str,subtitle:str=None):
        """Print a system-information panel to the console."""
        sys_info_panel = Panel(
            msg,
            title=title,
            title_align="left",
            style="bright_blue",
            border_style="white",
            subtitle=f"[gray37]{subtitle}[/gray37]" if subtitle else None,
            subtitle_align="right",
            padding=(1, 2),
        )
        console.print(sys_info_panel)

    def assistantResponse(self,msg:str|Markdown,subtitle:str=None):
        """Build (without printing) the panel used for assistant output."""
        return Panel(
            msg,
            title="[Assistant]",
            title_align="left",
            style="white",
            border_style="green",
            subtitle=f"[gray37]{subtitle}[/gray37]" if subtitle else None,
            subtitle_align="left",
            box=NoSlideBox(),
            padding=(1, 2),
        )

    async def showAndGetAssistantResponse(self,call_llm: callable,subtitle:str=None):
        """Run an LLM call while showing a live timer, then show the reply.

        While the call is pending, pressing "P" cancels it.

        Args:
            call_llm: Zero-argument callable returning an awaitable that
                resolves to a (assistant_response, usage) tuple.
            subtitle: Optional text placed before the elapsed time in the
                panel subtitle.

        Returns:
            A (parsed, raw) tuple: `raw` is the reply text; `parsed` is the
            JSON-decoded reply when the text is valid JSON, otherwise the
            same text.
        """
        with Live(auto_refresh=False) as live:
            loop = asyncio.get_running_loop()
            start_time = loop.time()
            elapsed = 0.0
            task = asyncio.create_task(call_llm())
            input_obj = create_input()
            while not task.done():
                for key in input_obj.read_keys():
                    if key.data.upper() == 'P' and not task.done():
                        task.cancel()
                elapsed = loop.time() - start_time
                process_info = f"Waiting: ⌛ Cost [bold red]{elapsed:.2f}[/bold red] Sec"
                live.update(self.assistantResponse(process_info, "[Press P to Cancel]"))
                live.refresh()
                await asyncio.sleep(0.2)  # keep CPU usage low while polling
            try:
                input_obj.close()
                result, usage = task.result()
            except (asyncio.CancelledError, Exception) as e:
                result, usage = f"⚠️ You Cancelled Or Exception Occurred {e}", None
            if usage:  # remember the latest reported usage
                self.usage = usage
            # A client may hand back a non-string (e.g. None); normalise it so
            # JSON parsing and Markdown rendering below cannot fail on it.
            if not isinstance(result, str):
                result = "" if result is None else str(result)
            try:
                opt_result = json.loads(result)
                show_result = f"""```json\n{json.dumps(opt_result, indent=2, ensure_ascii=False)}\n```"""
            except json.JSONDecodeError:
                show_result = opt_result = result
            cost_info = f"🚩 Cost {elapsed:.2f} Sec"
            final_subtitle = cost_info if not subtitle else f"{subtitle} | {cost_info}"
            live.update(self.assistantResponse(Markdown(show_result), final_subtitle))
            live.refresh()
            return opt_result, result

    def toolCalledPanel(self,toolName:str,args: None | dict,process_info:str=None,out_put:str=None,subtitle:str=None):
        """Build the panel for a tool call.

        With `out_put` the panel shows the finished call as Markdown;
        otherwise it shows the in-progress view including `process_info`.
        """
        if out_put:
            content = Markdown(
                f"Tool Called: *{toolName}*\n"
                f"Arguments: *{args}*\n"
                f"\n{out_put}"
            )
        else:
            progress_suffix = f"\n{process_info}" if process_info else ""
            content = (
                f"Tool Calling: [bold yellow]{toolName}[/bold yellow]\n"
                f"Arguments: [bold light_sea_green]{args}[/bold light_sea_green]"
                f"{progress_suffix}"
            )
        return Panel(
            content,
            title="[Tool]",
            title_align="left",
            border_style="magenta",
            subtitle=f"[gray37]{subtitle}[/gray37]" if subtitle else None,
            subtitle_align="left",
            padding=(1, 2),
        )

    def switch_model(self, model_no: str):
        """Switch the model of the LLM client.

        Args:
            model_no: The model number (a key of `client_models`) to switch to.

        Returns:
            True if the model was switched successfully, False otherwise.
        """
        model_info = client_models.get(model_no)
        if not model_info:
            error_console.print(f"Invalid model number: {model_no}")
            return False
        model_kwargs = {
            "ai_channel": model_info["ai_channel"],
            "ai_model": model_info["ai_model"],
            "ai_api_url": model_info["ai_api_url"],
            "ai_provider": model_info["ai_provider"],
        }
        provider = model_info["ai_provider"].lower()
        if provider == "openai":
            self.llm_client = LLMClient(
                api_key=config.open_api_key, http_proxy=config.open_proxy, **model_kwargs
            )
        elif provider == "deepseek":
            self.llm_client = LLMClient(
                api_key=config.deepseek_api_key, http_proxy=config.deepseek_proxy, **model_kwargs
            )
        else:
            # Every other provider is served by the custom platform client,
            # using the CUSTOM_* settings held by Configuration.
            self.llm_client = LLMClient2(
                config.custom_open_id, http_proxy=config.custom_proxy, **model_kwargs
            )
        return True

    async def cleanup_servers(self) -> None:
        """Clean up all servers properly (in reverse order)."""
        for server in reversed(self.servers):
            try:
                await server.cleanup()
            except Exception as e:
                server_console.print(f"Warning during final cleanup: {e}")

    async def process_use_prompt(self, input_prompt: str) -> list[PromptMessage]|None:
        """Resolve a prompt by name, ask for its arguments and render it.

        Args:
            input_prompt: Exact name of the prompt to use.

        Returns:
            The prompt messages returned by the first server exposing a
            prompt with that name, or None when no server has it or the
            server could not render it.
        """
        selected_prompt = None
        selected_server = None
        for server in self.servers:
            try:
                mcp_prompts = await server.list_prompts()
            except RuntimeError as e:
                error_console.print(f"❌ Failed to list prompts from {server.name}: {e}")
                continue
            match = next((p for p in mcp_prompts or [] if p.name == input_prompt), None)
            if match is not None:
                selected_prompt, selected_server = match, server
                break
        if not selected_prompt:
            return None
        self.showSysInfo(selected_prompt.format_for_rich,"[Selected Prompt]","If Arguments are present, please fill them.")
        args = {}
        for arg in selected_prompt.arguments or []:
            user_input = Prompt.ask(f"> Fill [bold bright_cyan]{arg.name}[/bold bright_cyan]").strip()
            # Optional arguments left blank are simply not sent.
            if not user_input and not arg.required:
                continue
            args[arg.name] = user_input
        tool_resp_prompts = await selected_server.get_prompt(selected_prompt.name, args)
        if not tool_resp_prompts:
            return None
        return tool_resp_prompts.messages

    async def process_llm_response(self, llm_response: str|dict) -> str:
        """Process the LLM response and execute tools if needed.

        Args:
            llm_response: The (possibly JSON-decoded) response from the LLM.

        Returns:
            The unchanged `llm_response` when it is not a dict with a "tool"
            key; otherwise a string with the tool execution result or an
            error description.
        """
        if not isinstance(llm_response, dict) or "tool" not in llm_response:
            return llm_response
        tool_name = llm_response["tool"]
        tool_args = llm_response.get("arguments")
        for server in self.servers:
            tools = await server.list_tools()
            if not any(tool.name == tool_name for tool in tools):
                continue
            try:
                with Live(auto_refresh=False) as live:
                    loop = asyncio.get_running_loop()
                    start_time = loop.time()
                    elapsed = 0.0
                    task = asyncio.create_task(server.execute_tool(tool_name, tool_args))
                    # Show the elapsed time while the tool is running.
                    while not task.done():
                        elapsed = loop.time() - start_time
                        process_info = f"Running: 🕒 Cost [bold red]{elapsed:.2f}[/bold red] Sec"
                        live.update(self.toolCalledPanel(tool_name, tool_args, process_info=process_info))
                        live.refresh()
                        await asyncio.sleep(0.2)  # keep CPU usage low while polling
                    result = task.result()
                    if result.content and result.content[0].type == 'text':
                        calledRst = result.content[0].text.strip()
                        try:
                            data = json.loads(calledRst)
                            out_put = f"""\n```json\n{json.dumps(data, indent=2, ensure_ascii=False)}\n```"""
                        except Exception:
                            out_put = f"\n{calledRst}"
                    else:
                        # Non-text (or empty) content: show its repr so the
                        # result panel always has something to render.
                        calledRst = f"{result.content}"
                        out_put = f"\n{calledRst}"
                    finish_info = f"{'❌' if result.isError else '🚩'} Cost {elapsed:.2f} Sec"
                    live.update(self.toolCalledPanel(tool_name, tool_args, out_put=out_put, subtitle=finish_info))
                    live.refresh()
                    return f"Tool execution result: {calledRst}"
            except Exception as e:
                error_msg = f"Error executing tool: {str(e)}"
                error_console.print(f"{error_msg}")
                return error_msg
        return f"No server found with tool: {tool_name}"

    def get_tool_details(self,all_tools:list[Tool],tool_part_name: str) -> str:
        """Describe the tools whose name contains the given text.

        Args:
            all_tools: Tools to search.
            tool_part_name: Case-insensitive part of the name to match.

        Returns:
            Rich-markup text with one entry per matching tool, or a
            "not found" message.
        """
        needle = tool_part_name.lower()
        filter_tools = [tool for tool in all_tools if needle in tool.name.lower()]
        if not filter_tools:
            return "No tools found with that name."
        return "\r\n".join(
            f"+ [bold white]{tool.server_name}[/bold white] - [bold yellow]{tool.name}[/bold yellow]"
            + (f"\n > [magenta]description[/magenta]: {tool.description.strip()}" if tool.description else "")
            for tool in filter_tools
        )

    def get_prompt_details(self,all_prompts:list[MCPPrompt],prompt_part_name: str) -> str:
        """Describe the prompts whose name contains the given text.

        Args:
            all_prompts: Prompts to search.
            prompt_part_name: Case-insensitive part of the name to match.

        Returns:
            Rich-markup text with one entry per matching prompt, or a
            "not found" message.
        """
        needle = prompt_part_name.lower()
        filter_prompts = [prompt for prompt in all_prompts if needle in prompt.name.lower()]
        if not filter_prompts:
            return "No prompts found with that name."
        return "\r\n".join(prompt.format_for_rich for prompt in filter_prompts)

    def _model_banner(self) -> str:
        """Return the one-line description of the current LLM client."""
        client = self.llm_client
        return f"{client.ai_channel} [white]|[/white] {client.ai_model} [white]|[/white] {client.ai_provider}"

    @staticmethod
    def _names_grid(names: list[str]) -> str:
        """Lay names out as "+ name" items, three per line."""
        grid = ""
        last = len(names) - 1
        for idx, name in enumerate(names):
            grid += f"+ {name}"
            if idx < last:
                grid += " [white]|[/white] " if (idx + 1) % 3 != 0 else "\n"
        return grid

    @staticmethod
    def _build_system_message(tools_description: str) -> str:
        """Build the system prompt that teaches the LLM the tool-call protocol."""
        return (
            "You are a helpful assistant with access to these tools:\n\n"
            f"{tools_description}\n"
            "Choose the appropriate tool based on the user's question. "
            "If no tool is needed, reply directly.\n\n"
            "IMPORTANT: When you need to use a tool, you must ONLY respond with "
            "the exact JSON object format below, nothing else:\n"
            "{\n"
            ' "tool": "tool-name",\n'
            ' "arguments": {\n'
            ' "argument-name": "value"\n'
            " }\n"
            "}\n\n"
            "Attention: If multiple tools (quantity > 1) with similar meanings are identified due to the user's question:\n "
            "list these tools and provide them with serial numbers (starting from 1 and incrementing sequentially) for the user to choose from, "
            "you must ONLY respond with the exact JSON object format below, nothing else:\n"
            "[{\n"
            ' "No.": 1,\n'
            ' "tool": "tool-name",\n'
            ' "arguments": {\n'
            ' "argument-name": "value"\n'
            " }\n"
            "},\n\n"
            "{\n"
            ' "No.": 2,\n'
            ' "tool": "tool-name",\n'
            ' "arguments": {\n'
            ' "argument-name": "value"\n'
            " }\n"
            "}]\n\n"
            "After receiving a tool's response:\n"
            "1. Transform the raw data into a natural, conversational response\n"
            "2. Keep responses concise but informative\n"
            "3. Focus on the most relevant information\n"
            "4. Use appropriate context from the user's question\n"
            "5. Avoid simply repeating the raw data\n\n"
            "Please use only the tools that are explicitly defined above."
        )

    async def start(self) -> None:
        """Main chat session handler.

        Initializes all servers, collects their tools and prompts, then runs
        the interactive loop (commands, LLM calls, tool calls) until the user
        quits. Servers are always cleaned up on the way out.
        """
        try:
            for server in self.servers:
                try:
                    await server.initialize()
                except Exception as e:
                    error_console.print(f"Failed to initialize server: {e}")
                    await self.cleanup_servers()
                    return

            all_tools = []
            all_tools_nameFormat = []
            all_prompts = []
            all_prompts_nameFormat = []
            for server in self.servers:
                tools = await server.list_tools()
                all_tools.extend(tools)
                # Per-server tool names, three per line.
                all_tools_nameFormat.append({
                    "server_name": server.name,
                    "ser_tools_nameFormat": self._names_grid([tool.name for tool in tools]),
                })
                mcp_prompts = await server.list_prompts()
                all_prompts.extend(mcp_prompts)
                all_prompts_nameFormat.append({
                    "server_name": server.name,
                    "ser_prompts_nameFormat": self._names_grid([prompt.name for prompt in mcp_prompts]),
                })

            # Tool descriptions are the core of the system prompt.
            tools_description = "\n".join([tool.format_for_llm() for tool in all_tools])
            tools_name = "\r\n".join([f"[bold yellow]{tool['server_name']}[/bold yellow]\n{tool['ser_tools_nameFormat']}" for tool in all_tools_nameFormat])
            prompts_name = "\r\n".join([f"[bold yellow]{prompt['server_name']}[/bold yellow]\n{prompt['ser_prompts_nameFormat']}" for prompt in all_prompts_nameFormat])
            models_options = "\n".join(f"{no}. {model['ai_channel']} [white]|[/white] {model['ai_model']} [white]|[/white] {model['ai_provider']}" for no, model in client_models.items())
            commands = ("[bold yellow]clh[/bold yellow]:clean history [white]|[/white] "
                        "[bold yellow]cls[/bold yellow]:clean screen [white]|[/white] "
                        "[bold yellow]swm[/bold yellow]:switch model [white]|[/white] "
                        "[bold yellow]uml[/bold yellow]:use multiLine\n"
                        "[bold yellow]lst[/bold yellow]:list tools [white]|[/white] "
                        "[bold yellow]std *tool_name*[/bold yellow]:show tool details [white]|[/white] "
                        "[bold yellow]stu[/bold yellow]:show tokenUsage\n"
                        "[bold yellow]lsp[/bold yellow]:list prompts [white]|[/white] "
                        "[bold yellow]spd *prompt_name*[/bold yellow]:show prompt details [white]|[/white] "
                        "[bold yellow]usp[/bold yellow]:use prompt"
                        )
            self.showSysInfo(self._model_banner(), "[Current AI Model]")  # show the current model
            self.showSysInfo(commands, "[Commands]", "quit | exit")

            messages = [{"role": "system", "content": self._build_system_message(tools_description)}]
            # Rich label shown for each prompt-message role accepted by "usp".
            role_labels = {
                "system": "[blue]System[/blue]",
                "user": "[bright_cyan]User[/bright_cyan]",
                "assistant": "[green]Assistant[/green]",
            }

            while True:
                try:
                    user_input = Prompt.ask("\n[bold bright_cyan][You 💬][/bold bright_cyan]").strip()
                    console.print("")  # blank line
                    if not user_input:
                        error_console.print("⚠️ You Need Input Something...")
                        continue
                    if user_input.lower() in ["std","spd"]:
                        error_console.print("⚠️ Your Command is not complete...")
                        continue
                    if user_input.lower() in ["lst", "list tools"]:
                        self.showSysInfo(tools_name,"[MCP Tools]")
                        continue
                    if user_input.lower().startswith("std "):
                        tool_name = user_input[3:].strip()
                        self.showSysInfo(self.get_tool_details(all_tools, tool_name), "[Tools Details]")
                        continue
                    if user_input.lower().startswith("spd "):
                        prompt_name = user_input[3:].strip()
                        self.showSysInfo(self.get_prompt_details(all_prompts, prompt_name), "[Prompts Details]")
                        continue
                    if user_input.lower() in ["clh", "clean history"]:
                        del messages[1:]  # keep only the system prompt
                        self.usage = None
                        server_console.print("📢 Cleaned History...")
                        continue
                    if user_input.lower() in ["cls", "clean screen"]:
                        if os.name == 'posix':  # Unix/Linux/Mac
                            print("\033c", end="")
                        elif os.name in ('nt', 'dos'):  # Windows
                            os.system('cls')
                        self.showSysInfo(self._model_banner(), "[Current AI Model]")
                        self.showSysInfo(commands, "[Commands]", "quit | exit")
                        continue
                    if user_input.lower() in ["swm","switch model"]:
                        self.showSysInfo(models_options,"[AI Model Options]")
                        user_input = Prompt.ask("[bold cyan]Choose 🤔[/bold cyan]",choices=list(client_models.keys())).strip()
                        self.switch_model(user_input)
                        self.showSysInfo(self._model_banner(), "[Current AI Model]")
                        # A new model starts from a clean conversation.
                        del messages[1:]
                        self.usage = None
                        continue
                    if user_input.lower() in ["stu","show token usage"]:
                        if not self.usage:
                            usageInfo = "No token usage information available"
                        else:
                            usageInfo = (f"Prompt Tokens: [yellow]{self.usage['prompt_tokens']}[/yellow]\n"
                                         f"Completion Tokens: [yellow]{self.usage['completion_tokens']}[/yellow]\n"
                                         f"Total Tokens: [yellow]{self.usage['total_tokens']}[/yellow]")
                        self.showSysInfo(usageInfo,"[Token Usage]")
                        continue
                    if user_input.lower() in ["uml","use multiline"]:
                        console.print("[bold]已开启多行输入[/bold](按Esc·Enter提交)")
                        session = PromptSession()
                        # The multi-line text replaces the command and flows on
                        # through the remaining checks below.
                        user_input = (await session.prompt_async("> ", multiline=True)).strip()
                        console.print("")  # blank line
                        if not user_input:
                            console.print("⚠️ You Need Input Something...")
                            continue
                    if user_input.lower() in ["lsp", "list prompts"]:
                        self.showSysInfo(prompts_name,"[Prompt List]")
                        continue
                    if user_input.lower() in ["quit", "exit"]:
                        console.print("💻 Exiting...")
                        break

                    # Use a prompt provided by one of the servers.
                    if user_input.lower() in ["usp", "use prompt"]:
                        console.print("[bold]输入可联想PromptName[/bold](按↑↓选择·Enter提交)")
                        word_completer = WordCompleter([prompt.name for prompt in all_prompts], ignore_case=True,match_middle=True)
                        session = PromptSession()
                        prm_input = (await session.prompt_async("> ", completer=word_completer)).strip()
                        console.print("")  # blank line
                        if prm_input not in word_completer.words:
                            console.print("⚠️ Invalid Prompt Name...")
                            continue
                        prompt_messages = await self.process_use_prompt(prm_input)
                        if not prompt_messages:
                            console.print(f"⚠️ Can't get Prompt - {prm_input} from MCP Servers...")
                            continue
                        show_prompts = []
                        for prompt_message in prompt_messages:
                            role = prompt_message.role.lower()
                            # Only text content of the known roles is forwarded.
                            if role in role_labels and isinstance(prompt_message.content, TextContent):
                                text = prompt_message.content.text
                                messages.append({"role": role, "content": text})
                                show_prompts.append(f"{role_labels[role]}: [white]{text}[/white]")
                        if not show_prompts:
                            console.print(f"⚠️ No text message from Prompt - {prm_input} from MCP Servers...")
                            continue
                        console.print("")  # blank line
                        self.showSysInfo("\n".join(show_prompts),"[Used Prompt]")
                    else:
                        messages.append({"role": "user", "content": user_input})

                    # Ask the LLM.
                    llm_response, orig_llm_response = await self.showAndGetAssistantResponse(
                        lambda: self.llm_client.get_response(messages)
                    )
                    # Handle MCP tool calls; loops so that one request can
                    # trigger several consecutive tool calls.
                    while True:
                        messages.append({"role": "assistant", "content": orig_llm_response})
                        try:
                            if isinstance(llm_response, list):
                                # The LLM offered several numbered candidates.
                                choices = [str(x["No."]) for x in llm_response if x.get("No.")]
                                if choices and choices[0]:
                                    user_input = Prompt.ask("[bold cyan]Choose NO.(0 to skip)🤔[/bold cyan]",choices=['0']+choices).strip()
                                    selectTools = [x for x in llm_response if str(x.get("No.")) == str(user_input)]
                                    if selectTools:
                                        llm_response = selectTools[0]
                                        messages.append({"role": "user", "content": f"I choose No.{user_input}"})
                                    else:
                                        llm_response = None
                                        messages.append({"role": "user", "content": f"I don't need any tools"})
                                        break
                            result = await self.process_llm_response(llm_response)
                            if result == llm_response:
                                break  # no tool was called; the reply is final
                            if isinstance(self.llm_client, LLMClient2):
                                messages.append({"role": "assistant", "content": "I have got the tool response from system, "+result})
                            else:
                                messages.append({"role": "system", "content": result})
                            llm_response, orig_llm_response = await self.showAndGetAssistantResponse(
                                lambda: self.llm_client.get_response(messages),
                                "[Tool Response Summary]",
                            )
                        except Exception as e:
                            error_console.print(f"{e}")
                            break
                except (KeyboardInterrupt, EOFError):
                    # Ctrl-C or end of input (Ctrl-D) both end the session.
                    console.print("💻 Exiting...")
                    break
        finally:
            await self.cleanup_servers()
async def main() -> None:
    """Create the enabled servers from servers_config.json and run the chat."""
    server_config = config.load_config("servers_config.json")
    servers = []
    for name, srv_config in server_config["mcpServers"].items():
        # Entries flagged "disabled" are skipped.
        if srv_config.get("disabled"):
            continue
        servers.append(Server(name, srv_config))
    # The session starts with the custom platform client.
    llm_client = LLMClient2(config.custom_open_id, http_proxy=config.custom_proxy)
    await ChatSession(servers, llm_client).start()


if __name__ == "__main__":
    asyncio.run(main())