mirror of
https://github.com/LeNei/axum-sqlx-template.git
synced 2026-02-13 22:56:19 +00:00
Base login setup
This commit is contained in:
@@ -9,14 +9,11 @@ async fn ping() -> StatusCode {
|
||||
StatusCode::OK
|
||||
}
|
||||
|
||||
pub fn routes(api_context: ApiContext) -> Router {
|
||||
pub fn routes() -> Router<ApiContext> {
|
||||
let cors = CorsLayer::new()
|
||||
// allow `GET` and `POST` when accessing the resource
|
||||
.allow_methods([Method::GET, Method::POST])
|
||||
// allow requests from any origin
|
||||
.allow_origin(Any);
|
||||
Router::new()
|
||||
.route("/ping", get(ping))
|
||||
.layer(cors)
|
||||
.with_state(api_context)
|
||||
Router::new().route("/ping", get(ping)).layer(cors)
|
||||
}
|
||||
|
||||
93
src/config/auth.rs
Normal file
93
src/config/auth.rs
Normal file
@@ -0,0 +1,93 @@
|
||||
use axum::{
|
||||
Extension, RequestPartsExt,
|
||||
body::Body,
|
||||
extract::{FromRequestParts, Request},
|
||||
middleware::Next,
|
||||
response::{IntoResponse, Redirect, Response},
|
||||
};
|
||||
use axum_login::{AuthnBackend, UserId};
|
||||
use http::request::Parts;
|
||||
use password_auth::verify_password;
|
||||
use serde::Deserialize;
|
||||
use tokio::task;
|
||||
use ts_rs::TS;
|
||||
|
||||
use crate::models::{InertiaError, user::User};
|
||||
|
||||
use super::ApiContext;
|
||||
|
||||
// This allows us to extract the authentication fields from forms. We use this
|
||||
// to authenticate requests with the backend.
|
||||
#[derive(Debug, Clone, Deserialize, TS)]
|
||||
#[ts(export)]
|
||||
pub struct Credentials {
|
||||
pub username: String,
|
||||
pub password: String,
|
||||
pub next: Option<String>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl AuthnBackend for ApiContext {
|
||||
type User = User;
|
||||
type Credentials = Credentials;
|
||||
type Error = InertiaError;
|
||||
|
||||
async fn authenticate(
|
||||
&self,
|
||||
creds: Self::Credentials,
|
||||
) -> Result<Option<Self::User>, Self::Error> {
|
||||
let user = sqlx::query_as!(
|
||||
User,
|
||||
"select * from users where username = $1",
|
||||
creds.username
|
||||
)
|
||||
.fetch_optional(&self.db)
|
||||
.await?;
|
||||
|
||||
// Verifying the password is blocking and potentially slow, so we'll do so via
|
||||
// `spawn_blocking`.
|
||||
task::spawn_blocking(|| {
|
||||
// We're using password-based authentication--this works by comparing our form
|
||||
// input with an argon2 password hash.
|
||||
Ok(user.filter(|user| verify_password(creds.password, &user.password).is_ok()))
|
||||
})
|
||||
.await?
|
||||
}
|
||||
|
||||
async fn get_user(&self, user_id: &UserId<Self>) -> Result<Option<Self::User>, Self::Error> {
|
||||
let user = sqlx::query_as!(User, "select * from users where id = $1", user_id)
|
||||
.fetch_optional(&self.db)
|
||||
.await?;
|
||||
|
||||
Ok(user)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_user(auth_session: AuthSession, mut req: Request<Body>, next: Next) -> Response {
|
||||
match auth_session.user {
|
||||
Some(user) => {
|
||||
req.extensions_mut().insert(user);
|
||||
next.run(req).await
|
||||
}
|
||||
None => Redirect::to("/login").into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> FromRequestParts<S> for User
|
||||
where
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = Response;
|
||||
async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Response> {
|
||||
let Extension(user) = parts
|
||||
.extract::<Extension<User>>()
|
||||
.await
|
||||
.map_err(|_| Redirect::to("/login").into_response())?;
|
||||
Ok(user)
|
||||
}
|
||||
}
|
||||
|
||||
// We use a type alias for convenience.
|
||||
//
|
||||
// Note that we've supplied our concrete backend here.
|
||||
pub type AuthSession = axum_login::AuthSession<ApiContext>;
|
||||
@@ -1,4 +1,5 @@
|
||||
pub mod app;
|
||||
pub mod auth;
|
||||
pub mod database;
|
||||
pub mod inertia;
|
||||
pub mod logging;
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
pub mod api;
|
||||
pub mod config;
|
||||
pub mod models;
|
||||
pub mod pages;
|
||||
pub mod startup;
|
||||
|
||||
0
src/models/error.rs
Normal file
0
src/models/error.rs
Normal file
23
src/models/mod.rs
Normal file
23
src/models/mod.rs
Normal file
@@ -0,0 +1,23 @@
|
||||
pub mod user;
|
||||
|
||||
use axum::response::{IntoResponse, Redirect, Response};
|
||||
use tokio::task;
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum InertiaError {
|
||||
#[error(transparent)]
|
||||
Sqlx(#[from] sqlx::Error),
|
||||
|
||||
#[error(transparent)]
|
||||
TaskJoin(#[from] task::JoinError),
|
||||
|
||||
#[error("Something went wrong")]
|
||||
Unknown,
|
||||
}
|
||||
|
||||
impl IntoResponse for InertiaError {
|
||||
fn into_response(self) -> Response {
|
||||
tracing::error!("Error: {:?}", self);
|
||||
Redirect::to("/error").into_response()
|
||||
}
|
||||
}
|
||||
74
src/models/user.rs
Normal file
74
src/models/user.rs
Normal file
@@ -0,0 +1,74 @@
|
||||
use axum_login::AuthUser;
|
||||
use chrono::{DateTime, Utc};
|
||||
use password_auth::generate_hash;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::{Error, PgConnection, prelude::FromRow, types::Uuid};
|
||||
use ts_rs::TS;
|
||||
|
||||
use super::InertiaError;
|
||||
|
||||
#[derive(FromRow, Serialize, Clone, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export)]
|
||||
pub struct User {
|
||||
pub id: Uuid,
|
||||
pub username: String,
|
||||
#[serde(skip)]
|
||||
pub password: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
// Here we've implemented `Debug` manually to avoid accidentally logging the
|
||||
// password hash.
|
||||
impl std::fmt::Debug for User {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("User")
|
||||
.field("id", &self.id)
|
||||
.field("username", &self.username)
|
||||
.field("password", &"[redacted]")
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl AuthUser for User {
|
||||
type Id = Uuid;
|
||||
|
||||
fn id(&self) -> Self::Id {
|
||||
self.id
|
||||
}
|
||||
|
||||
fn session_auth_hash(&self) -> &[u8] {
|
||||
self.password.as_bytes() // We use the password hash as the auth
|
||||
// hash--what this means
|
||||
// is when the user changes their password the
|
||||
// auth session becomes invalid.
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Clone, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export)]
|
||||
pub struct NewUser {
|
||||
pub username: String,
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
impl NewUser {
|
||||
pub async fn register(self, conn: &mut PgConnection) -> Result<User, Error> {
|
||||
let password = generate_hash(&self.password);
|
||||
|
||||
sqlx::query_as!(
|
||||
User,
|
||||
r#"
|
||||
INSERT INTO users (username, password)
|
||||
VALUES ($1, $2)
|
||||
RETURNING *
|
||||
"#,
|
||||
self.username,
|
||||
password,
|
||||
)
|
||||
.fetch_one(conn)
|
||||
.await
|
||||
}
|
||||
}
|
||||
78
src/pages/auth.rs
Normal file
78
src/pages/auth.rs
Normal file
@@ -0,0 +1,78 @@
|
||||
use axum::{
|
||||
Json, Router,
|
||||
extract::State,
|
||||
response::{IntoResponse, Redirect},
|
||||
routing::get,
|
||||
};
|
||||
use axum_inertia::Inertia;
|
||||
use http::StatusCode;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::{
|
||||
config::{
|
||||
ApiContext,
|
||||
auth::{AuthSession, Credentials},
|
||||
},
|
||||
models::{InertiaError, user::NewUser},
|
||||
};
|
||||
|
||||
#[tracing::instrument(name = "Login Page", skip(i))]
|
||||
async fn login_page(i: Inertia) -> impl IntoResponse {
|
||||
i.render("Login", json!({}))
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "Login attempt", skip(auth_session, creds))]
|
||||
async fn login(mut auth_session: AuthSession, Json(creds): Json<Credentials>) -> impl IntoResponse {
|
||||
let user = match auth_session.authenticate(creds.clone()).await {
|
||||
Ok(Some(user)) => user,
|
||||
Ok(None) => {
|
||||
let mut login_url = "/login".to_string();
|
||||
if let Some(next) = creds.next {
|
||||
login_url = format!("{}?next={}", login_url, next);
|
||||
};
|
||||
|
||||
return Redirect::to(&login_url).into_response();
|
||||
}
|
||||
Err(_) => return StatusCode::INTERNAL_SERVER_ERROR.into_response(),
|
||||
};
|
||||
|
||||
if auth_session.login(&user).await.is_err() {
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
|
||||
if let Some(ref next) = creds.next {
|
||||
Redirect::to(next)
|
||||
} else {
|
||||
Redirect::to("/")
|
||||
}
|
||||
.into_response()
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "Register page", skip(i))]
|
||||
async fn register_page(i: Inertia) -> impl IntoResponse {
|
||||
i.render("Register", json!({}))
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "Registration attempt", skip(ctx, new_user))]
|
||||
async fn register(
|
||||
State(ctx): State<ApiContext>,
|
||||
Json(new_user): Json<NewUser>,
|
||||
) -> Result<impl IntoResponse, InertiaError> {
|
||||
let mut conn = ctx.db.acquire().await?;
|
||||
new_user.register(&mut conn).await?;
|
||||
Ok(Redirect::to("/login").into_response())
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "Logout", skip(auth_session))]
|
||||
async fn logout(mut auth_session: AuthSession) -> Result<impl IntoResponse, InertiaError> {
|
||||
match auth_session.logout().await {
|
||||
Ok(_) => Ok(Redirect::to("/login").into_response()),
|
||||
Err(_) => Err(InertiaError::Unknown),
|
||||
}
|
||||
}
|
||||
pub fn routes() -> Router<ApiContext> {
|
||||
Router::new()
|
||||
.route("/login", get(login_page).post(login))
|
||||
.route("/register", get(register_page).post(register))
|
||||
.route("/logout", get(logout))
|
||||
}
|
||||
@@ -1,22 +1,47 @@
|
||||
use axum::{Router, response::IntoResponse, routing::get};
|
||||
mod auth;
|
||||
use axum::{Router, middleware, response::IntoResponse, routing::get};
|
||||
use axum_inertia::Inertia;
|
||||
use axum_login::AuthManagerLayerBuilder;
|
||||
use http::Method;
|
||||
use serde_json::json;
|
||||
use tower_http::cors::CorsLayer;
|
||||
use tower_sessions::{Expiry, MemoryStore, SessionManagerLayer, cookie::time::Duration};
|
||||
|
||||
use crate::config::ApiContext;
|
||||
use crate::{
|
||||
config::{ApiContext, auth::get_user},
|
||||
models::user::User,
|
||||
};
|
||||
|
||||
#[tracing::instrument(name = "Home Page", skip(i))]
|
||||
async fn home(i: Inertia) -> impl IntoResponse {
|
||||
i.render("Home", json!({}))
|
||||
async fn home(i: Inertia, user: User) -> impl IntoResponse {
|
||||
i.render("Home", json!({ "user": user }))
|
||||
}
|
||||
|
||||
pub fn routes(api_context: ApiContext) -> Router {
|
||||
pub fn routes(api_context: ApiContext) -> Router<ApiContext> {
|
||||
let cors = CorsLayer::new()
|
||||
// allow `GET` and `POST` when accessing the resource
|
||||
.allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE]);
|
||||
|
||||
// Create a session store and layer
|
||||
let session_store = MemoryStore::default();
|
||||
let session_layer = SessionManagerLayer::new(session_store)
|
||||
.with_expiry(Expiry::OnInactivity(Duration::days(1)));
|
||||
let auth_layer = AuthManagerLayerBuilder::new(api_context.clone(), session_layer).build();
|
||||
|
||||
Router::new()
|
||||
.merge(auth::routes())
|
||||
.merge(protected_routes())
|
||||
.layer(cors)
|
||||
.layer(auth_layer)
|
||||
}
|
||||
|
||||
pub fn protected_routes() -> Router<ApiContext> {
|
||||
let cors = CorsLayer::new()
|
||||
// allow `GET` and `POST` when accessing the resource
|
||||
.allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE]);
|
||||
|
||||
Router::new()
|
||||
.route("/", get(home))
|
||||
.layer(cors)
|
||||
.with_state(api_context)
|
||||
.route_layer(middleware::from_fn(get_user))
|
||||
}
|
||||
|
||||
@@ -18,7 +18,8 @@ pub async fn build(settings: Settings) -> anyhow::Result<()> {
|
||||
tracing::info!("Creating router...");
|
||||
let mut router = Router::new()
|
||||
.merge(page_routes(api_context.clone()))
|
||||
.nest("/api", api_routes(api_context))
|
||||
.nest("/api", api_routes())
|
||||
.with_state(api_context)
|
||||
.layer(TraceLayer::new_for_http());
|
||||
|
||||
if !settings.is_dev {
|
||||
|
||||
Reference in New Issue
Block a user