use axum::Router; use config::config; use sqlx::{postgres::PgPoolOptions, PgPool}; use std::fs::File; use std::sync::Arc; use std::time::Duration; use tokio::net::TcpListener; use tokio::signal; use tower_governor::{governor::GovernorConfigBuilder, GovernorLayer}; use tower_http::trace::{self, TraceLayer}; use tracing_subscriber::{filter, layer::SubscriberExt, prelude::*, util::SubscriberInitExt}; mod config; mod datasources; mod routes; pub struct AppState { db: PgPool, } #[tokio::main] async fn main() { // setting up configuration let _ = config(); // setting up logging let stdout_log = tracing_subscriber::fmt::layer().pretty(); let file = File::create("debug.log"); let file = match file { Ok(file) => file, Err(error) => panic!("Error: {:?}", error), }; let debug_log = tracing_subscriber::fmt::layer().with_writer(Arc::new(file)); let metrics_layer = filter::LevelFilter::INFO; tracing_subscriber::registry() .with( stdout_log // Add an `INFO` filter to the stdout logging layer .with_filter(filter::LevelFilter::INFO) // Combine the filtered `stdout_log` layer with the // `debug_log` layer, producing a new `Layered` layer. .and_then(debug_log) // Add a filter to *both* layers that rejects spans and // events whose targets start with `metrics`. .with_filter(filter::filter_fn(|metadata| { !metadata.target().starts_with("metrics") })), ) .with( // Add a filter to the metrics label that *only* enables // events whose targets start with `metrics`. metrics_layer.with_filter(filter::filter_fn(|metadata| { metadata.target().starts_with("metrics") })), ) .init(); if std::env::var("RUST_ENV").expect("development") != "development" { println!("we're not in development, starting up the rate limiter"); let governor_conf = Arc::new( GovernorConfigBuilder::default() .per_second(2) .burst_size(5) .finish() .unwrap(), ); let governor_limiter = governor_conf.limiter().clone(); let interval = Duration::from_secs(60); // a separate background task to clean up std::thread::spawn(move || loop { std::thread::sleep(interval); tracing::info!("rate limiting storage size: {}", governor_limiter.len()); governor_limiter.retain_recent(); }); } // grabbing the database url from our env variables let db_connection_str = std::env::var("DATABASE_URL") .unwrap_or_else(|_| "postgres://postgres:password@localhost".to_string()); // set up connection pool let pool = PgPoolOptions::new() .max_connections(10) .acquire_timeout(Duration::from_secs(3)) .connect(&db_connection_str) .await .expect("Failed to connect to database"); let app_state = AppState { db: pool.clone() }; // build our application with some routes let app = Router::new() .nest("/", routes::root::RootRoute::routes()) .nest("/posts", routes::posts::PostsRoute::routes(&app_state)) .layer( TraceLayer::new_for_http() .make_span_with(trace::DefaultMakeSpan::new().level(tracing::Level::INFO)) .on_response(trace::DefaultOnResponse::new().level(tracing::Level::INFO)), ); // .nest( // "/comments", // routes::comments::CommentsRoute::routes(&app_state), // ); // run it with hyper let listener = TcpListener::bind("0.0.0.0:3000").await.unwrap(); tracing::debug!("listening on {}", listener.local_addr().unwrap()); axum::serve(listener, app) .with_graceful_shutdown(shutdown_signal()) .await .unwrap(); } async fn shutdown_signal() { let ctrl_c = async { signal::ctrl_c() .await .expect("Failed to install Ctrl+C handler"); }; #[cfg(unix)] let terminate = async { signal::unix::signal(signal::unix::SignalKind::terminate()) .expect("Failed to install signal handler") .recv() .await; }; #[cfg(not(unix))] let terminate = std::future::pending::<()>(); tokio::select! { _ = ctrl_c => {}, _ = terminate => {}, } }