AtomStorm/01_creativity/tool_use.ipynb
2025-02-05 21:33:11 +08:00

276 lines
12 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Openai Function Call \n",
"https://platform.openai.com/docs/guides/function-calling"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"User>\t How's the weather in Hangzhou?\n",
"[ChatCompletionMessageToolCall(id='call_zvxBS1zoCa6X9fJzE48BYo2z', function=Function(arguments='{\"location\":\"Hangzhou, China\"}', name='get_weather', parameters=None), type='function')]\n",
"https://api.weatherapi.com/v1/current.json?key=b7e69154e60b427586983619250502&q=Hangzhou, China\n",
"Tool call>\t Temperature: 13.0°C, Condition: Partly cloudy\n",
"Model>\t The current weather in Hangzhou, China is partly cloudy with a temperature of 13.0°C.\n"
]
}
],
"source": [
"from openai import OpenAI\n",
"import json\n",
"import requests\n",
"\n",
"# 定义获取天气的工具\n",
"def get_weather(location: str):\n",
" \"\"\"调用天气 API 获取指定地点的天气信息\"\"\"\n",
" url = f\"https://api.weatherapi.com/v1/current.json?key=b7e69154e60b427586983619250502&q={location}\"\n",
" print(url)\n",
" response = requests.get(url)\n",
" if response.status_code == 200:\n",
" data = response.json()\n",
" return f\"Temperature: {data['current']['temp_c']}°C, Condition: {data['current']['condition']['text']}\"\n",
" return \"Weather data not found\"\n",
"\n",
"def send_messages(messages):\n",
" response = client.chat.completions.create(\n",
" model=\"gpt-4o\",\n",
" messages=messages,\n",
" tools=tools, # 指定可调用的函数\n",
" )\n",
" return response.choices[0].message\n",
"\n",
"client = OpenAI(\n",
" api_key=\"sk-lXFW7Bl1ruw2qmHu287e979847354601A07fE2D85a567bD7\", # pass litellm proxy key, if you're using virtual keys\n",
" base_url=\"https://yunwu.ai/v1/\" # litellm-proxy-base url\n",
")\n",
"\n",
"tools = [{\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_weather\",\n",
" \"description\": \"Get current temperature for a given location.\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"location\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"City and country e.g. Bogotá, Colombia\"\n",
" }\n",
" },\n",
" \"required\": [\n",
" \"location\"\n",
" ],\n",
" \"additionalProperties\": False\n",
" },\n",
" \"strict\": True\n",
" }\n",
"}]\n",
"# step1\n",
"messages = [{\"role\": \"user\", \"content\": \"How's the weather in Hangzhou?\"}]\n",
"message = send_messages(messages)\n",
"print(f\"User>\\t {messages[0]['content']}\")\n",
"print(message.tool_calls)\n",
"\n",
"# step2\n",
"tool_call = message.tool_calls[0]\n",
"\n",
"#step3 tool call chioce to local function call\n",
"args = json.loads(tool_call.function.arguments)\n",
"result = get_weather(args[\"location\"])\n",
"print(\"Tool call>\\t\", result)\n",
"\n",
"# step4\n",
"messages.append(message) # append model's function call message\n",
"messages.append({ # append result message\n",
" \"role\": \"tool\",\n",
" \"tool_call_id\": tool_call.id,\n",
" \"content\": result\n",
"})\n",
"\n",
"message = send_messages(messages)\n",
"\n",
"# step5\n",
"print(\"Model>\\t\", message.content)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Qwen Tool Use Example"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"------------------------------------------------------------\n",
"\n",
"第1轮大模型输出信息{'id': 'chatcmpl-91008e14-9254-98b9-8b14-da6632f6c711', 'choices': [{'finish_reason': 'tool_calls', 'index': 0, 'logprobs': None, 'message': {'content': '', 'refusal': None, 'role': 'assistant', 'audio': None, 'function_call': None, 'tool_calls': [{'id': 'call_5ae24ff6a07e4a1093fff6', 'function': {'arguments': '{}', 'name': 'get_current_time'}, 'type': 'function', 'index': 0}]}}], 'created': 1738762202, 'model': 'qwen-max-2025-01-25', 'object': 'chat.completion', 'service_tier': None, 'system_fingerprint': None, 'usage': {'completion_tokens': 13, 'prompt_tokens': 219, 'total_tokens': 232, 'completion_tokens_details': None, 'prompt_tokens_details': None}}\n",
"\n",
"工具输出信息当前时间2025-02-05 21:29:58。\n",
"\n",
"------------------------------------------------------------\n",
"第2轮大模型输出信息{'content': '现在的时间是2025年2月5日晚上9点29分58秒。', 'refusal': None, 'role': 'assistant', 'audio': None, 'function_call': None, 'tool_calls': None}\n",
"\n",
"最终答案现在的时间是2025年2月5日晚上9点29分58秒。\n"
]
}
],
"source": [
"from openai import OpenAI\n",
"from datetime import datetime\n",
"import json\n",
"import os\n",
"\n",
"client = OpenAI(\n",
" # 若没有配置环境变量请用百炼API Key将下行替换为api_key=\"sk-xxx\",\n",
" api_key=\"sk-487d366becc14df0a58ff4d4559fabf4\",\n",
" base_url=\"https://dashscope.aliyuncs.com/compatible-mode/v1\", # 填写DashScope SDK的base_url\n",
")\n",
"\n",
"# 定义工具列表模型在选择使用哪个工具时会参考工具的name和description\n",
"tools = [\n",
" # 工具1 获取当前时刻的时间\n",
" {\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_current_time\",\n",
" \"description\": \"当你想知道现在的时间时非常有用。\",\n",
" # 因为获取当前时间无需输入参数因此parameters为空字典\n",
" \"parameters\": {}\n",
" }\n",
" }, \n",
" # 工具2 获取指定城市的天气\n",
" {\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_current_weather\",\n",
" \"description\": \"当你想查询指定城市的天气时非常有用。\",\n",
" \"parameters\": { \n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" # 查询天气时需要提供位置因此参数设置为location\n",
" \"location\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"城市或县区,比如北京市、杭州市、余杭区等。\"\n",
" }\n",
" }\n",
" },\n",
" \"required\": [\n",
" \"location\"\n",
" ]\n",
" }\n",
" }\n",
"]\n",
"\n",
"# 模拟天气查询工具。返回结果示例:“北京今天是雨天。”\n",
"def get_current_weather(location):\n",
" return f\"{location}今天是雨天。 \"\n",
"\n",
"# 查询当前时间的工具。返回结果示例“当前时间2024-04-15 17:15:18。“\n",
"def get_current_time():\n",
" # 获取当前日期和时间\n",
" current_datetime = datetime.now()\n",
" # 格式化当前日期和时间\n",
" formatted_time = current_datetime.strftime('%Y-%m-%d %H:%M:%S')\n",
" # 返回格式化后的当前时间\n",
" return f\"当前时间:{formatted_time}。\"\n",
"\n",
"# 封装模型响应函数\n",
"def get_response(messages):\n",
" completion = client.chat.completions.create(\n",
" model=\"qwen-max-2025-01-25\",\n",
" messages=messages,\n",
" tools=tools\n",
" )\n",
" return completion.model_dump()\n",
"\n",
"def call_with_messages():\n",
" print('\\n')\n",
" messages = [\n",
" {\n",
" \"content\": input('现在几点了?'), # 提问示例:\"现在几点了?\" \"一个小时后几点\" \"北京天气如何?\"\n",
" \"role\": \"user\"\n",
" }\n",
" ]\n",
" print(\"-\"*60)\n",
" # 模型的第一轮调用\n",
" i = 1\n",
" first_response = get_response(messages)\n",
" assistant_output = first_response['choices'][0]['message']\n",
" print(f\"\\n第{i}轮大模型输出信息:{first_response}\\n\")\n",
" if assistant_output['content'] is None:\n",
" assistant_output['content'] = \"\"\n",
" messages.append(assistant_output)\n",
" # 如果不需要调用工具,则直接返回最终答案\n",
" if assistant_output['tool_calls'] == None: # 如果模型判断无需调用工具则将assistant的回复直接打印出来无需进行模型的第二轮调用\n",
" print(f\"无需调用工具,我可以直接回复:{assistant_output['content']}\")\n",
" return\n",
" # 如果需要调用工具,则进行模型的多轮调用,直到模型判断无需调用工具\n",
" while assistant_output['tool_calls'] != None:\n",
" # 如果判断需要调用查询天气工具,则运行查询天气工具\n",
" if assistant_output['tool_calls'][0]['function']['name'] == 'get_current_weather':\n",
" tool_info = {\"name\": \"get_current_weather\", \"role\":\"tool\"}\n",
" # 提取位置参数信息\n",
" location = json.loads(assistant_output['tool_calls'][0]['function']['arguments'])['location']\n",
" tool_info['content'] = get_current_weather(location)\n",
" # 如果判断需要调用查询时间工具,则运行查询时间工具\n",
" elif assistant_output['tool_calls'][0]['function']['name'] == 'get_current_time':\n",
" tool_info = {\"name\": \"get_current_time\", \"role\":\"tool\"}\n",
" tool_info['content'] = get_current_time()\n",
" print(f\"工具输出信息:{tool_info['content']}\\n\")\n",
" print(\"-\"*60)\n",
" messages.append(tool_info)\n",
" assistant_output = get_response(messages)['choices'][0]['message']\n",
" if assistant_output['content'] is None:\n",
" assistant_output['content'] = \"\"\n",
" messages.append(assistant_output)\n",
" i += 1\n",
" print(f\"第{i}轮大模型输出信息:{assistant_output}\\n\")\n",
" print(f\"最终答案:{assistant_output['content']}\")\n",
"\n",
"call_with_messages()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}