跳转到正文
zeno's blog
返回

axum(三):中间件与生产实践-tower 原生的 Web 应用

专题: axum

Table of contents

Open Table of contents

TL;DR

axum 不发明中间件系统——Router::layer() 直接接受任何 tower::Layer。tower-http 的 Trace/Compression/CORS/Timeout 全部开箱即用。from_fn 让你用普通 async 函数写中间件而不需要实现 Service + Layer。注意 .layer() 只影响它之前添加的路由,顺序是洋葱模型(后加的在外层)。


中间件接入

tower 中间件:直接 .layer()

use tower_http::trace::TraceLayer;
use tower_http::compression::CompressionLayer;
use tower_http::cors::CorsLayer;
use tower::timeout::TimeoutLayer;

let app = Router::new()
    .route("/api/users", get(list_users))
    .route("/api/items", get(list_items))
    .layer(TraceLayer::new_for_http())
    .layer(CompressionLayer::new())
    .layer(CorsLayer::permissive())
    .layer(TimeoutLayer::new(Duration::from_secs(30)));

函数式中间件:from_fn

不想实现 Service + Layer trait 时,用 axum::middleware::from_fn

use axum::middleware::{self, Next};

async fn auth_middleware(req: Request, next: Next) -> Result<Response, StatusCode> {
    let token = req.headers()
        .get("Authorization")
        .and_then(|v| v.to_str().ok())
        .ok_or(StatusCode::UNAUTHORIZED)?;

    // 验证 token...
    if !is_valid(token) {
        return Err(StatusCode::FORBIDDEN);
    }

    Ok(next.run(req).await)
}

let app = Router::new()
    .route("/protected", get(handler))
    .layer(middleware::from_fn(auth_middleware));

需要访问状态时用 from_fn_with_state

async fn auth(
    State(state): State<AppState>,
    req: Request,
    next: Next,
) -> Response {
    // state 可用
    next.run(req).await
}

let app = Router::new()
    .route("/", get(handler))
    .layer(middleware::from_fn_with_state(state.clone(), auth))
    .with_state(state);

Extractor 作为中间件

from_extractor 把提取器变成中间件——提取成功则继续,失败则返回 rejection:

// 定义一个提取器用于鉴权
struct RequireAuth(AuthUser);

impl<S: Send + Sync> FromRequestParts<S> for RequireAuth {
    type Rejection = StatusCode;
    async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
        // 从 header 中提取并验证 token
        // 成功返回 RequireAuth(user),失败返回 401
    }
}

let app = Router::new()
    .route("/admin", get(admin_handler))
    .layer(middleware::from_extractor::<RequireAuth>());

中间件作用范围和顺序

layer() vs route_layer()

方法作用范围404 处理
Router::layer()所有路由(包括未匹配的 404)中间件会对 404 响应也执行
Router::route_layer()仅匹配到的路由404 不经过此中间件

route_layer() 适合鉴权中间件——不应该对未匹配路径执行鉴权逻辑。

洋葱模型

.layer() 的顺序是后加的在外层(与 tower ServiceBuilder 相反):

Router::new()
    .layer(layer_c)   // 最外层:请求最先经过,响应最后经过
    .layer(layer_b)   // 中间层
    .layer(layer_a)   // 最内层:请求最后经过,响应最先经过

// 请求流: → layer_c → layer_b → layer_a → handler
// 响应流: ← layer_c ← layer_b ← layer_a ← handler

ServiceBuilder 时顺序翻转(自上而下):

.layer(
    ServiceBuilder::new()
        .layer(layer_a)   // 最外层
        .layer(layer_b)
        .layer(layer_c)   // 最内层
)

作用范围陷阱

.layer() 只影响之前添加的路由

Router::new()
    .route("/public", get(public_handler))    // 有中间件
    .layer(auth_layer)
    .route("/health", get(health_check))      // 没有中间件!

这是设计选择——允许选择性应用中间件。但它经常让人意外。

常用 tower-http 中间件

中间件用途Cargo feature
TraceLayer请求/响应追踪(配合 tracing)trace
CompressionLayer响应压缩(gzip, br, zstd)compression-*
CorsLayer跨域资源共享cors
TimeoutLayer请求超时— (来自 tower)
SetRequestHeaderLayer注入请求头set-header
SetResponseHeaderLayer注入响应头set-header
RequestBodyLimitLayer限制请求体大小limit
CatchPanicLayer将 panic 转为 500 响应catch-panic
ValidateRequestHeaderLayer头部验证 / Basic Authvalidate-request
RequestIdLayer请求 ID 生成和传播request-id

错误处理

标准模式:自定义错误类型 + IntoResponse

use thiserror::Error;

#[derive(Error, Debug)]
enum AppError {
    #[error("not found")]
    NotFound,

    #[error("unauthorized")]
    Unauthorized,

    #[error("database error: {0}")]
    Database(#[from] sqlx::Error),

    #[error("validation: {0}")]
    Validation(String),
}

impl IntoResponse for AppError {
    fn into_response(self) -> Response {
        let (status, message) = match &self {
            AppError::NotFound => (StatusCode::NOT_FOUND, self.to_string()),
            AppError::Unauthorized => (StatusCode::UNAUTHORIZED, self.to_string()),
            AppError::Database(e) => {
                tracing::error!(%e, "database error");
                (StatusCode::INTERNAL_SERVER_ERROR, "internal error".into())
            }
            AppError::Validation(msg) => (StatusCode::BAD_REQUEST, msg.clone()),
        };
        (status, Json(json!({ "error": message }))).into_response()
    }
}

// handler 中用 ? 自然传播错误
async fn get_user(
    Path(id): Path<i64>,
    State(db): State<PgPool>,
) -> Result<Json<User>, AppError> {
    let user = sqlx::query_as!(User, "SELECT * FROM users WHERE id = $1", id)
        .fetch_optional(&db)
        .await?
        .ok_or(AppError::NotFound)?;
    Ok(Json(user))
}

关键原则

陷阱

1. Extractor 顺序错误

body 消耗者(Json, String, Bytes, Form)不在最后一个参数 → 编译错误。错误信息是 Handler trait bound 失败,不会直接说「顺序不对」。

修复:body 消耗者永远放最后。

2. State 类型不匹配

handler 提取 State<AppState> 但 router 用了不同的 state 类型 → 编译错误,但信息是 FromRefFromRequestParts 相关的多行 trait bound 失败。

修复:确保 Router<S>State<S>S 一致。

3. Handler trait bound 错误

Handler<_, _> is not implemented for fn(...) 的常见原因:

修复:加 #[axum::debug_handler](需要 macros feature)获取针对性诊断。release 构建中无开销。

4. tower feature flag 缺失

tower 和 tower-http 用细粒度 feature flag。忘了开 tower-http/compressiontower/timeoutLayer trait not implemented 错误,但类型确实存在。

修复:检查 Cargo.toml 中的 feature 矩阵。

5. .layer() 之后添加的路由没有中间件

.layer() 之后的 .route() 不受该中间件影响——静默地绕过了。

6. Clone 要求

State<T> 要求 T: Clone。如果 state 包含不可 Clone 的类型 → 编译错误在 .with_state() 而非 handler 处。

修复:用 Arc 包装不可 Clone 的字段。大多数数据库连接池(sqlx::PgPool, deadpool::Pool)内部已经用了 Arc,直接 Clone 即可。

7. std::sync::Mutex 跨 .await

持有 std::sync::Mutex guard 跨 .await 点 → future 变成 !SendHandler trait bound 失败。

修复:用 tokio::sync::Mutex,或在 .await 前 drop guard。

8. 嵌套 Router 的 fallback 继承

嵌套的 router 继承外层的 fallback,除非它定义了自己的。两个有 fallback 的 router merge 会 panic。

修复:merge 前在一侧调用 .reset_fallback()

生产模式

测试

Router 实现 tower::Service<Request>,可以用 tower::ServiceExt::oneshot 直接测试,不需要启动 HTTP 服务器:

use tower::ServiceExt;

#[tokio::test]
async fn test_get_users() {
    let app = Router::new()
        .route("/users", get(list_users))
        .with_state(test_state());

    let response = app
        .oneshot(
            Request::builder()
                .uri("/users")
                .body(Body::empty())
                .unwrap(),
        )
        .await
        .unwrap();

    assert_eq!(response.status(), StatusCode::OK);
}

.oneshot() 只在 Router<()> 上可用(state 已提供)。

Tracing 集成

use tower_http::trace::TraceLayer;

let app = Router::new()
    .route("/", get(handler))
    .layer(
        TraceLayer::new_for_http()
            .make_span_with(DefaultMakeSpan::new().level(Level::INFO))
            .on_response(DefaultOnResponse::new().level(Level::INFO)),
    );

tracing_subscriber::fmt()
    .with_target(false)
    .with_env_filter("info,tower_http=debug,axum::rejection=trace")
    .init();

axum::rejection=trace 在 trace 级别记录所有提取失败,不需要改代码就能调试提取问题。

优雅关闭

let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await?;
axum::serve(listener, app)
    .with_graceful_shutdown(async {
        tokio::signal::ctrl_c().await.expect("failed to install handler");
    })
    .await?;

分层架构

Handler 层    ← HTTP 细节:提取器、响应、路由
Service 层    ← 业务逻辑:用例、验证、编排
Repository 层 ← 数据访问:数据库查询、外部 API

分享这篇文章:

上一篇
可观测性(二):Rust 可观测性体系的架构理解
下一篇
可观测性(一):三大支柱-Logs、Traces、Metrics