跳转到正文
zeno's blog
返回

Rust Web:Axum HTTP Server 指南

专题: Rust Web

Table of contents

Open Table of contents

Cargo.toml

[dependencies]
axum = "0.8"
tokio = { version = "1", features = ["rt-multi-thread", "macros"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
tower-http = { version = "0.6", features = ["cors", "trace", "compression-full"] }
tower = { version = "0.5", features = ["timeout"] }
tracing = "0.1"
tracing-subscriber = "0.3"
# Needed by the error-handling example, whose `AppError::Internal` wraps `anyhow::Error`.
anyhow = "1"

1. 最小 Server

use axum::{routing::get, Router};

/// Entry point: serves a single `GET /` route that replies with a fixed greeting.
///
/// Panics if binding the listener or running the server fails.
#[tokio::main]
async fn main() {
    // Bind first so the address is claimed before the router is assembled.
    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000")
        .await
        .unwrap();

    let hello = get(|| async { "Hello, World!" });
    let app = Router::new().route("/", hello);

    axum::serve(listener, app).await.unwrap();
}

2. 路由

use axum::{routing::{get, post, put, delete}, Router};

// Build one method router per path, then mount them.
let users = get(list_users).post(create_user);
let user_by_id = get(get_user).put(update_user).delete(delete_user);

let app = Router::new()
    .route("/", get(root))
    .route("/users", users)
    .route("/users/{id}", user_by_id);

嵌套路由

/// Builds the router for the API resources: `/users`, `/users/{id}` and `/items`.
fn api_routes() -> Router {
    let users = Router::new()
        .route("/users", get(list_users))
        .route("/users/{id}", get(get_user));

    users.route("/items", get(list_items))
}

// Mount everything returned by `api_routes()` under the `/api/v1` prefix.
let api = api_routes();
let app = Router::new()
    .route("/health", get(health_check))
    .nest("/api/v1", api);
// Resulting paths: /health, /api/v1/users, /api/v1/users/{id}, /api/v1/items

3. Handler 与 Extractor

axum 的核心设计:handler 就是普通 async 函数,参数通过 extractor 自动从请求中提取。

Path 参数

use axum::extract::Path;

/// Replies with `"User <id>"`, where `id` is the path segment extracted as a `u32`.
async fn get_user(path: Path<u32>) -> String {
    let Path(id) = path;
    format!("User {}", id)
}

/// Multiple path parameters: both segments are extracted together as a tuple.
///
/// Route: `/users/{user_id}/items/{item_id}`
async fn get_item(path: Path<(u32, u32)>) -> String {
    let Path((user_id, item_id)) = path;
    format!("User {} Item {}", user_id, item_id)
}

Query 参数

use axum::extract::Query;
use serde::Deserialize;

/// Query-string parameters for paginated listings; both fields are optional.
#[derive(Deserialize)]
struct Pagination {
    /// Requested page number, if supplied.
    page: Option<u32>,
    /// Requested page size, if supplied.
    per_page: Option<u32>,
}

async fn list_users(Query(params): Query<Pagination>) -> String {
    let page = params.page.unwrap_or(1);
    let per_page = params.per_page.unwrap_or(20);
    format!("Page {} ({} per page)", page, per_page)
}
// GET /users?page=2&per_page=10

JSON Body

use axum::{Json, http::StatusCode};
use serde::{Deserialize, Serialize};

/// JSON request body accepted when creating a user.
#[derive(Deserialize)]
struct CreateUser {
    /// Name chosen for the new user.
    username: String,
    /// E-mail address of the new user.
    email: String,
}

/// User record serialized into JSON responses.
#[derive(Serialize)]
struct User {
    /// Numeric identifier of the user.
    id: u64,
    /// User name.
    username: String,
    /// E-mail address.
    email: String,
}

async fn create_user(Json(payload): Json<CreateUser>) -> (StatusCode, Json<User>) {
    let user = User {
        id: 1,
        username: payload.username,
        email: payload.email,
    };
    (StatusCode::CREATED, Json(user))
}

多个 Extractor 组合

use axum::extract::{Path, Query, State, Json};

/// Example handler combining three extractors: shared state, a path parameter and a JSON body.
///
/// The body-consuming `Json` extractor is the last parameter. The handler body is
/// intentionally elided in this example.
async fn update_user(
    State(db): State<DbPool>,
    Path(id): Path<u32>,
    Json(body): Json<UpdateUser>,
) -> Result<Json<User>, StatusCode> {
    // ...
}

规则:消费 body 的 extractor(Json, String, Bytes)必须放在参数最后一个。

Header 提取

use axum::http::HeaderMap;

/// Replies with the request's `user-agent` header value.
///
/// Falls back to `"unknown"` when the header is missing or `to_str()` fails on its value.
async fn handler(headers: HeaderMap) -> String {
    let ua = match headers.get("user-agent").map(|value| value.to_str()) {
        Some(Ok(text)) => text,
        _ => "unknown",
    };
    format!("UA: {}", ua)
}

4. 共享状态(State)

use axum::extract::State;
use std::sync::Arc;

/// Application state handed to handlers through the `State` extractor.
#[derive(Clone)]
struct AppState {
    /// Database pool held behind an `Arc`.
    db: Arc<DbPool>,
    /// Application configuration; the example reads `api_key` from it.
    config: AppConfig,
}

async fn handler(State(state): State<AppState>) -> String {
    format!("Connected to DB, API key: {}", state.config.api_key)
}

/// Entry point: builds the shared `AppState`, attaches it to the router and serves on port 3000.
///
/// Panics if binding the listener or running the server fails.
#[tokio::main]
async fn main() {
    // Construct the pool before the config, matching the field order of `AppState`.
    let db = Arc::new(DbPool::new());
    let config = AppConfig { api_key: "secret".into() };
    let state = AppState { db, config };

    let routes = Router::new().route("/", get(handler));
    let app = routes.with_state(state);

    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000")
        .await
        .unwrap();
    axum::serve(listener, app).await.unwrap();
}

5. 响应(IntoResponse)

任何实现了 IntoResponse 的类型都能作为 handler 返回值。

use axum::{http::StatusCode, response::{Json, IntoResponse}};

/// Plain string response: returns the static text `"hello"`.
async fn text() -> &'static str {
    const BODY: &str = "hello";
    BODY
}

/// Status code plus JSON: replies `200 OK` with the body `{"status": "ok"}`.
async fn json() -> (StatusCode, Json<serde_json::Value>) {
    let body = serde_json::json!({"status": "ok"});
    (StatusCode::OK, Json(body))
}

/// Custom response assembled from a tuple: status code, one header, then the body.
async fn custom() -> impl IntoResponse {
    let headers = [("X-Custom-Header", "value")];
    (StatusCode::OK, headers, "body content")
}

6. 错误处理

use axum::{http::StatusCode, response::{IntoResponse, Json}};

/// Application-level error returned by handlers.
enum AppError {
    /// A resource was not found; carries the message for the client.
    NotFound(String),
    /// The request was invalid; carries the message for the client.
    BadRequest(String),
    /// Any other failure, wrapped as an `anyhow::Error`.
    Internal(anyhow::Error),
}

impl IntoResponse for AppError {
    fn into_response(self) -> axum::response::Response {
        let (status, message) = match self {
            AppError::NotFound(msg) => (StatusCode::NOT_FOUND, msg),
            AppError::BadRequest(msg) => (StatusCode::BAD_REQUEST, msg),
            AppError::Internal(err) => (
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("Internal error: {}", err),
            ),
        };
        (status, Json(serde_json::json!({"error": message}))).into_response()
    }
}

/// Looks up a user by the `id` path parameter.
///
/// # Errors
///
/// - `AppError::BadRequest` when `id` is `0`.
/// - `AppError::NotFound` when `find_user` yields no user for `id`.
async fn get_user(Path(id): Path<u32>) -> Result<Json<User>, AppError> {
    if id == 0 {
        return Err(AppError::BadRequest("ID cannot be 0".into()));
    }
    // `ok_or_else` builds the error message lazily, so the success path does not allocate.
    let user = find_user(id)
        .ok_or_else(|| AppError::NotFound(format!("User {} not found", id)))?;
    Ok(Json(user))
}

Extractor 错误处理

use axum::extract::rejection::JsonRejection;

/// Handles the JSON extractor's rejection explicitly by taking a `Result`.
///
/// Replies `201 Created` with `{"ok": true}` when the body was extracted, otherwise
/// `400 Bad Request` with the rejection's message under `"error"`.
async fn create_user(payload: Result<Json<CreateUser>, JsonRejection>) -> impl IntoResponse {
    match payload {
        // The parsed body is not used in this example; the underscore prefix avoids an
        // unused-variable warning.
        Ok(Json(_user)) => (StatusCode::CREATED, Json(serde_json::json!({"ok": true}))),
        Err(err) => (StatusCode::BAD_REQUEST, Json(serde_json::json!({"error": err.to_string()}))),
    }
}

7. 中间件(Tower Layer)

axum 直接使用 tower 生态的中间件。

use tower_http::{
    cors::{CorsLayer, Any},
    trace::TraceLayer,
    compression::CompressionLayer,
};
use tower::{ServiceBuilder, timeout::TimeoutLayer};
use std::time::Duration;
use axum::error_handling::HandleErrorLayer;

// Router with a timeout stack, compression, CORS and request tracing layered on top.
let app = Router::new()
    .route("/", get(handler))
    .layer(
        ServiceBuilder::new()
            // Turns any boxed error coming out of this stack into `408 Request Timeout`.
            .layer(HandleErrorLayer::new(|_: axum::BoxError| async {
                StatusCode::REQUEST_TIMEOUT
            }))
            // 30-second timeout.
            .layer(TimeoutLayer::new(Duration::from_secs(30)))
    )
    // Compression layer with its default settings.
    .layer(CompressionLayer::new())
    // CORS: any origin and any method are allowed.
    .layer(CorsLayer::new().allow_origin(Any).allow_methods(Any))
    // HTTP tracing layer.
    .layer(TraceLayer::new_for_http());

注意:在 Router 上连续调用 .layer() 时,执行顺序是从下往上(后注册的在最外层,最先处理请求);而 ServiceBuilder 内部的 layer 则相反,按从上往下的顺序执行。

8. WebSocket

# Cargo.toml: the WebSocket example enables axum's `ws` feature
axum = { version = "0.8", features = ["ws"] }
use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
use axum::response::IntoResponse;

async fn ws_handler(ws: WebSocketUpgrade) -> impl IntoResponse {
    ws.on_upgrade(handle_socket)
}

/// Echo loop for one WebSocket connection.
///
/// Every text message is answered with `"Echo: <text>"`. The loop ends when the stream
/// finishes, a receive or send fails, or a close message arrives; other message kinds
/// are ignored.
async fn handle_socket(mut socket: WebSocket) {
    loop {
        // Stop on end-of-stream (`None`) as well as on a receive error.
        let Some(Ok(msg)) = socket.recv().await else {
            break;
        };

        match msg {
            Message::Text(text) => {
                let reply = format!("Echo: {}", text);
                let sent = socket.send(Message::Text(reply.into())).await;
                if sent.is_err() {
                    break;
                }
            }
            Message::Close(_) => break,
            _ => {}
        }
    }
}

// Route registration: `any` lets `ws_handler` serve every HTTP method on `/ws`.
let ws_route = axum::routing::any(ws_handler);
let app = Router::new().route("/ws", ws_route);

9. 完整示例

use axum::{
    extract::{Path, Query, State, Json},
    http::StatusCode,
    response::IntoResponse,
    routing::{get, post},
    Router,
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tower_http::trace::TraceLayer;

/// Shared application state handed to handlers through the `State` extractor.
#[derive(Clone)]
struct AppState {
    /// Database pool held behind an `Arc`.
    db: Arc<DbPool>,
}

/// Optional query-string parameters for listing players.
#[derive(Deserialize)]
struct ListParams {
    /// Requested page number, if supplied.
    page: Option<u32>,
    /// Maximum number of entries requested, if supplied.
    limit: Option<u32>,
}

/// Player record serialized into JSON responses.
#[derive(Serialize)]
struct Player {
    /// Numeric identifier of the player.
    id: u32,
    /// Player name.
    name: String,
    /// Player level.
    level: u32,
}

/// JSON request body accepted when creating a player.
#[derive(Deserialize)]
struct CreatePlayer {
    /// Name of the player to create.
    name: String,
}

/// Lists players. Takes the shared state and the optional paging parameters.
///
/// The lookup itself is elided in this example; an empty list is returned.
async fn list_players(
    State(state): State<AppState>,
    Query(params): Query<ListParams>,
) -> Json<Vec<Player>> {
    // ...
    Json(vec![])
}

/// Fetches one player by the `id` path parameter.
///
/// The lookup itself is elided in this example; it always answers `404 Not Found`.
async fn get_player(
    State(state): State<AppState>,
    Path(id): Path<u32>,
) -> Result<Json<Player>, StatusCode> {
    // ...
    Err(StatusCode::NOT_FOUND)
}

async fn create_player(
    State(state): State<AppState>,
    Json(body): Json<CreatePlayer>,
) -> (StatusCode, Json<Player>) {
    let player = Player { id: 1, name: body.name, level: 1 };
    (StatusCode::CREATED, Json(player))
}

/// Entry point of the complete example: initializes tracing, wires the player routes
/// with shared state and serves them on port 3000.
///
/// Panics if binding the listener or running the server fails.
#[tokio::main]
async fn main() {
    tracing_subscriber::fmt::init();

    let state = AppState { db: Arc::new(DbPool) };

    // One method router per path keeps the route table easy to scan.
    let players = get(list_players).post(create_player);
    let player_by_id = get(get_player);

    let app = Router::new()
        .route("/players", players)
        .route("/players/{id}", player_by_id)
        .layer(TraceLayer::new_for_http())
        .with_state(state);

    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000")
        .await
        .unwrap();
    println!("Listening on 0.0.0.0:3000");
    axum::serve(listener, app).await.unwrap();
}

10. 与 actix-web 对比

| 对比项 | axum | actix-web |
| --- | --- | --- |
| 异步运行时 | tokio | actix-rt(基于 tokio) |
| 路由 | 函数式,无宏 | #[get] #[post] |
| 中间件 | tower 生态(通用) | 自有 trait(专用) |
| 状态共享 | State extractor | web::Data(内部 Arc) |
| 设计哲学 | 组合优于继承,一切皆 trait | 功能全面,开箱即用 |
| 生态集成 | tonic (gRPC) 同家族,无缝共存 | 独立生态 |

分享这篇文章:

上一篇
Rust Web:Axum 返回值-IntoResponse 与错误处理
下一篇
Rust Web:Axum 优雅停机(Graceful Shutdown)