Agent 工具怎么写?从原理到最佳实践
你用 LangGraph 跑通了官方示例,现在想加自己的工具。但官方文档讲得很浅:加个 @tool 装饰器就完事了?
实际上,写好一个工具远不止这些。错误处理怎么做?参数验证怎么做?安全性怎么保证?性能怎么优化?
这篇文章,我从原理到最佳实践,深入讲解 Agent 工具开发。
工具的本质是什么?
工具 = 输入 Schema + 执行逻辑 + 输出格式。
┌─────────────────────────────────────────────────┐
│ Agent 调用工具 │
└─────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────┐
│ 1. LLM 根据 Schema 生成参数 │
│ {"path": "README.md", "lines": 10} │
└─────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────┐
│ 2. 工具执行 │
│ - 参数验证 │
│ - 权限检查 │
│ - 执行操作 │
│ - 错误处理 │
└─────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────┐
│ 3. 返回结果 │
│ "README.md 的前 10 行是..." │
└─────────────────────────────────────────────────┘
LLM 不知道工具内部怎么实现,只知道:
- 工具能做什么(从 description)
- 需要什么参数(从 inputSchema)
工具 Schema 详解
inputSchema 的作用
inputSchema 告诉 LLM:
- 需要传哪些参数
- 参数的类型是什么
- 哪些参数是必需的
from langchain_core.tools import tool
from typing import Optional
@tool
def read_file(
path: str,
start_line: int = 1,
end_line: Optional[int] = None,
encoding: str = "utf-8",
) -> str:
"""读取文件的指定行
Args:
path: 文件路径(必需)
start_line: 起始行号,从 1 开始(默认 1)
end_line: 结束行号,不指定则读到文件末尾
encoding: 文件编码(默认 utf-8)
Returns:
文件内容
"""
pass
生成的 Schema:
{
"name": "read_file",
"description": "读取文件的指定行",
"parameters": {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "文件路径(必需)"
},
"start_line": {
"type": "integer",
"description": "起始行号,从 1 开始(默认 1)",
"default": 1
},
"end_line": {
"type": "integer",
"description": "结束行号,不指定则读到文件末尾"
},
"encoding": {
"type": "string",
"description": "文件编码(默认 utf-8)",
"default": "utf-8"
}
},
"required": ["path"]
}
}
LLM 如何使用这个 Schema:
用户:读取 README.md 的前 10 行
LLM 推理:
- path = "README.md"(必需参数)
- end_line = 10(用户指定了"前 10 行")
- start_line = 1(默认值)
- encoding = "utf-8"(默认值)
生成调用:
read_file(path="README.md", end_line=10)
description 的写法
description 决定 LLM 是否会正确使用工具。
❌ 错误示例:
@tool
def search(query: str) -> str:
"""搜索"""
pass
问题:
- 搜索什么?文件?网络?数据库?
- 返回什么格式?
- 有什么限制?
✅ 正确示例:
@tool
def search_files(keyword: str, directory: str = ".", max_results: int = 20) -> str:
"""在目录下搜索包含关键词的文件
功能:
- 搜索指定目录下所有文件
- 返回包含关键词的文件列表
- 支持递归搜索子目录
限制:
- 最多返回 max_results 个结果
- 只搜索文本文件(跳过二进制文件)
- 单文件大小限制 10MB
Args:
keyword: 搜索关键词
directory: 搜索目录,默认当前目录
max_results: 最大结果数,默认 20
Returns:
匹配的文件列表,格式:
- 文件路径: 匹配行数
"""
pass
工具实现最佳实践
1. 参数验证
永远不要相信 LLM 生成的参数。
from pathlib import Path
@tool
def read_file(path: str, max_size_mb: int = 10) -> str:
"""读取文件内容"""
# 参数验证
if not path:
return "错误:文件路径不能为空"
if max_size_mb <= 0 or max_size_mb > 100:
return "错误:max_size_mb 必须在 1-100 之间"
# 路径验证
try:
file_path = Path(path).resolve()
except Exception as e:
return f"错误:无效的路径 {path}"
# 文件存在性检查
if not file_path.exists():
return f"错误:文件不存在 {path}"
if not file_path.is_file():
return f"错误:{path} 不是文件"
# 文件大小检查
size_mb = file_path.stat().st_size / 1024 / 1024
if size_mb > max_size_mb:
return f"错误:文件过大 ({size_mb:.1f}MB),超过限制 ({max_size_mb}MB)"
# 执行读取
try:
return file_path.read_text(encoding="utf-8")
except UnicodeDecodeError:
return "错误:文件编码不支持,可能不是文本文件"
except PermissionError:
return f"错误:没有权限读取 {path}"
except Exception as e:
return f"错误:{type(e).__name__}: {str(e)}"
2. 安全检查
工具是 Agent 的"手脚",必须限制它的能力。
import os
from pathlib import Path
# 定义安全边界
ALLOWED_DIRECTORIES = [
Path.home() / "projects",
Path("/tmp"),
]
DENIED_PATTERNS = [
".env",
".pem",
".key",
"id_rsa",
"credentials",
]
def is_path_allowed(path: Path) -> tuple[bool, str]:
"""检查路径是否允许访问"""
# 解析绝对路径
try:
abs_path = path.resolve()
except Exception as e:
return False, f"无效路径: {e}"
# 检查是否在允许的目录内
in_allowed = any(
str(abs_path).startswith(str(allowed))
for allowed in ALLOWED_DIRECTORIES
)
if not in_allowed:
return False, f"路径不在允许的目录内: {abs_path}"
# 检查敏感文件
path_str = str(abs_path).lower()
for pattern in DENIED_PATTERNS:
if pattern in path_str:
return False, f"不允许访问敏感文件: {pattern}"
return True, "OK"
@tool
def write_file(path: str, content: str) -> str:
"""写入文件"""
# 安全检查
file_path = Path(path)
allowed, reason = is_path_allowed(file_path)
if not allowed:
return f"安全拒绝: {reason}"
# 执行写入
try:
file_path.write_text(content, encoding="utf-8")
return f"成功: 已写入 {path}"
except Exception as e:
return f"错误: {str(e)}"
3. 资源限制
防止工具消耗过多资源。
import time
import signal
from contextlib import contextmanager
class TimeoutError(Exception):
pass
@contextmanager
def time_limit(seconds: int):
"""执行时间限制"""
def signal_handler(signum, frame):
raise TimeoutError(f"执行超时 ({seconds}秒)")
signal.signal(signal.SIGALRM, signal_handler)
signal.alarm(seconds)
try:
yield
finally:
signal.alarm(0)
@tool
def execute_command(command: str, timeout: int = 30) -> str:
"""执行 shell 命令"""
# 安全检查
dangerous_commands = ["rm -rf", "sudo", "chmod 777", "> /dev/sda"]
for dangerous in dangerous_commands:
if dangerous in command:
return f"安全拒绝: 不允许执行危险命令"
# 资源限制
try:
with time_limit(timeout):
import subprocess
result = subprocess.run(
command,
shell=True,
capture_output=True,
text=True,
timeout=timeout,
)
return result.stdout or result.stderr
except TimeoutError:
return f"错误: 命令执行超时 ({timeout}秒)"
except Exception as e:
return f"错误: {str(e)}"
4. 结构化输出
返回结构化数据,方便 LLM 解析。
import json
from typing import Any
@tool
def analyze_code(file_path: str) -> str:
"""分析代码文件"""
# 分析代码
result = {
"file": file_path,
"language": detect_language(file_path),
"lines": 0,
"functions": [],
"classes": [],
"imports": [],
"complexity": "low",
}
# ... 分析逻辑 ...
# 返回 JSON 字符串
return json.dumps(result, ensure_ascii=False, indent=2)
# LLM 看到的输出:
"""
{
"file": "main.py",
"language": "Python",
"lines": 150,
"functions": ["main", "process_data", "save_result"],
"classes": ["DataProcessor"],
"imports": ["json", "pathlib", "typing"],
"complexity": "medium"
}
"""
工具注册和管理
简单方案:列表管理
# 定义工具列表
TOOLS = [
read_file,
write_file,
list_files,
search_files,
]
# 绑定到 LLM
from langchain_openai import ChatOpenAI
llm = ChatOpenAI(model="gpt-4")
llm_with_tools = llm.bind_tools(TOOLS)
进阶方案:工具注册表
from typing import Callable, Any
from dataclasses import dataclass, field
from enum import Enum
class ToolCategory(Enum):
READ = "read" # 只读操作
WRITE = "write" # 写入操作
EXECUTE = "execute" # 执行命令
NETWORK = "network" # 网络请求
@dataclass
class RegisteredTool:
tool: Callable
category: ToolCategory
cost: float = 0.0 # 预估成本
risk: str = "low" # low / medium / high
class ToolRegistry:
"""工具注册表"""
def __init__(self):
self._tools: dict[str, RegisteredTool] = {}
def register(
self,
tool: Callable,
category: ToolCategory,
cost: float = 0.0,
risk: str = "low",
):
"""注册工具"""
name = tool.name
self._tools[name] = RegisteredTool(
tool=tool,
category=category,
cost=cost,
risk=risk,
)
def get_tool(self, name: str) -> RegisteredTool | None:
"""获取工具"""
return self._tools.get(name)
def get_tools_by_category(self, category: ToolCategory) -> list[Callable]:
"""按类别获取工具"""
return [
rt.tool for rt in self._tools.values()
if rt.category == category
]
def get_tools_by_risk(self, max_risk: str = "medium") -> list[Callable]:
"""按风险等级获取工具"""
risk_levels = {"low": 1, "medium": 2, "high": 3}
max_level = risk_levels[max_risk]
return [
rt.tool for rt in self._tools.values()
if risk_levels[rt.risk] <= max_level
]
def get_langchain_tools(self) -> list[Callable]:
"""获取 LangChain 格式的工具列表"""
return [rt.tool for rt in self._tools.values()]
# 使用
registry = ToolRegistry()
registry.register(read_file, ToolCategory.READ, cost=0.001, risk="low")
registry.register(write_file, ToolCategory.WRITE, cost=0.002, risk="medium")
registry.register(execute_command, ToolCategory.EXECUTE, cost=0.01, risk="high")
# 按风险等级选择工具
safe_tools = registry.get_tools_by_risk("low") # 只允许低风险工具
异步工具
LLM 调用是异步的,工具也应该是异步的。
import aiofiles
import asyncio
@tool
async def read_file_async(path: str) -> str:
"""异步读取文件"""
try:
async with aiofiles.open(path, mode='r', encoding='utf-8') as f:
content = await f.read()
return content
except Exception as e:
return f"错误: {str(e)}"
# 在 Agent 中使用
async def execute_node(state: AgentState) -> dict:
"""执行节点(异步版本)"""
results = []
for tc in state["tool_calls"]:
tool = get_tool(tc["name"])
# 异步执行
if asyncio.iscoroutinefunction(tool.invoke):
result = await tool.ainvoke(tc["args"])
else:
result = tool.invoke(tc["args"])
results.append(result)
return {"tool_results": results}
工具测试
工具必须有单元测试。
import pytest
from tempfile import TemporaryDirectory
from pathlib import Path
def test_read_file_exists():
"""测试读取存在的文件"""
with TemporaryDirectory() as tmpdir:
test_file = Path(tmpdir) / "test.txt"
test_file.write_text("Hello, World!")
result = read_file.invoke({"path": str(test_file)})
assert "Hello, World!" in result
assert "错误" not in result
def test_read_file_not_exists():
"""测试读取不存在的文件"""
result = read_file.invoke({"path": "/nonexistent/file.txt"})
assert "错误" in result
assert "不存在" in result
def test_read_file_too_large():
"""测试读取超大文件"""
with TemporaryDirectory() as tmpdir:
test_file = Path(tmpdir) / "large.txt"
# 创建 15MB 文件
test_file.write_bytes(b"x" * (15 * 1024 * 1024))
result = read_file.invoke({
"path": str(test_file),
"max_size_mb": 10,
})
assert "错误" in result
assert "过大" in result
def test_read_file_binary():
"""测试读取二进制文件"""
with TemporaryDirectory() as tmpdir:
test_file = Path(tmpdir) / "binary.bin"
test_file.write_bytes(b"\x00\x01\x02\x03")
result = read_file.invoke({"path": str(test_file)})
assert "错误" in result or "编码" in result
# 运行测试
# pytest tests/test_tools.py -v
我踩过的真实坑
坑一:参数类型错误
现象:LLM 生成了错误的参数类型。
# LLM 生成的调用
read_file.invoke({"path": 123}) # 数字而不是字符串
# 工具内部报错
# AttributeError: 'int' object has no attribute 'exists'
解决:参数强制转换 + 类型检查。
def read_file(path: str) -> str:
path = str(path) # 强制转换
if not isinstance(path, str):
return "错误: path 必须是字符串"
# ... 继续处理
坑二:工具返回 None
现象:工具没有返回值,LLM 不知道发生了什么。
@tool
def write_file(path: str, content: str):
"""写入文件"""
Path(path).write_text(content)
# 没有 return!
解决:永远返回明确的确认信息。
@tool
def write_file(path: str, content: str) -> str:
"""写入文件"""
Path(path).write_text(content)
return f"成功: 已写入 {path},共 {len(content)} 字符" # 明确返回
坑三:工具抛异常
现象:工具抛异常,Agent 崩溃。
@tool
def read_file(path: str) -> str:
with open(path) as f: # 可能抛 FileNotFoundError
return f.read()
解决:捕获所有异常,返回错误信息。
@tool
def read_file(path: str) -> str:
try:
with open(path) as f:
return f.read()
except FileNotFoundError:
return f"错误: 文件不存在 {path}"
except Exception as e:
return f"错误: {type(e).__name__}: {str(e)}"
下一步行动
- 定义你的工具列表:Agent 需要什么能力?
- 写第一个工具:从最简单的开始
- 加参数验证和安全检查:不要相信任何输入
- 写单元测试:每个工具至少 3 个测试用例
工具是 Agent 的"手脚",写好工具,Agent 才能做真正有用的事。
工具不是简单的函数包装,而是需要考虑:参数验证、安全检查、资源限制、错误处理、结构化输出。这些做好了,Agent 才可靠。