use axum::{http::Method, Router}; use config::config; use sqlx::{postgres::PgPoolOptions, PgPool}; use std::fs::File; use std::sync::Arc; use std::time::Duration; use tokio::net::TcpListener; use tokio::signal; use tower_governor::{governor::GovernorConfigBuilder, GovernorLayer}; use tower_http::{ cors::{Any, CorsLayer}, trace::{self, TraceLayer}, }; use tracing_subscriber::{filter, layer::SubscriberExt, prelude::*, util::SubscriberInitExt}; mod config; mod datasources; mod routes; mod utils; pub struct AppState { db: PgPool, } #[tokio::main] async fn main() { // setting up configuration let _ = config(); // setting up logging let stdout_log = tracing_subscriber::fmt::layer().pretty(); let file = File::create("debug.log"); let file = match file { Ok(file) => file, Err(error) => panic!("Error: {:?}", error), }; let debug_log = tracing_subscriber::fmt::layer().with_writer(Arc::new(file)); let metrics_layer = filter::LevelFilter::INFO; tracing_subscriber::registry() .with( stdout_log // Add an `INFO` filter to the stdout logging layer .with_filter(filter::LevelFilter::INFO) // Combine the filtered `stdout_log` layer with the // `debug_log` layer, producing a new `Layered` layer. .and_then(debug_log) // Add a filter to *both* layers that rejects spans and // events whose targets start with `metrics`. .with_filter(filter::filter_fn(|metadata| { !metadata.target().starts_with("metrics") })), ) .with( // Add a filter to the metrics label that *only* enables // events whose targets start with `metrics`. metrics_layer.with_filter(filter::filter_fn(|metadata| { metadata.target().starts_with("metrics") })), ) .init(); let cors = CorsLayer::new() .allow_methods(Any) .allow_headers(Any) .allow_origin(Any); // if std::env::var("RUST_ENV").unwrap_or_else(|_| "development".to_string()) != "development" { //println!("we're not in development, starting up the rate limiter"); //let governor_conf = Arc::new( // GovernorConfigBuilder::default() // .per_second(2) // .burst_size(5) // .finish() // .unwrap(), //); // //let governor_limiter = governor_conf.limiter().clone(); //let interval = Duration::from_secs(60); //// a separate background task to clean up //std::thread::spawn(move || loop { // std::thread::sleep(interval); // tracing::info!("rate limiting storage size: {}", governor_limiter.len()); // governor_limiter.retain_recent(); //}); // } // grabbing the database url from our env variables let db_connection_str = std::env::var("DATABASE_URL") .unwrap_or_else(|_| "postgres://postgres:password@localhost".to_string()); // set up connection pool let pool = PgPoolOptions::new() .max_connections(10) .acquire_timeout(Duration::from_secs(3)) .connect(&db_connection_str) .await .expect("Failed to connect to database"); let app_state = AppState { db: pool.clone() }; // build our application with some routes let app = Router::new() .nest("/", routes::root::RootRoute::routes()) .nest("/posts", routes::posts::PostsRoute::routes(&app_state)) .nest( "/comments", routes::comments::CommentsRoute::routes(&app_state), ) .nest( "/authors", routes::authors::AuthorsRoute::routes(&app_state), ) .layer(CorsLayer::permissive()) .layer( TraceLayer::new_for_http() .make_span_with(trace::DefaultMakeSpan::new().level(tracing::Level::INFO)) .on_response(trace::DefaultOnResponse::new().level(tracing::Level::INFO)), ) .fallback(routes::root::RootRoute::not_found); // .layer(cors); //.layer(GovernorLayer { // config: governor_conf, //}); // run it with hyper let listener = TcpListener::bind("0.0.0.0:3000").await.unwrap(); tracing::debug!("listening on {}", listener.local_addr().unwrap()); axum::serve(listener, app) .with_graceful_shutdown(shutdown_signal()) .await .unwrap(); } async fn shutdown_signal() { let ctrl_c = async { signal::ctrl_c() .await .expect("Failed to install Ctrl+C handler"); }; #[cfg(unix)] let terminate = async { signal::unix::signal(signal::unix::SignalKind::terminate()) .expect("Failed to install signal handler") .recv() .await; }; #[cfg(not(unix))] let terminate = std::future::pending::<()>(); tokio::select! { _ = ctrl_c => {}, _ = terminate => {}, } }