diff --git a/backend/public/src/main.rs b/backend/public/src/main.rs index d5e331f..8473f1d 100644 --- a/backend/public/src/main.rs +++ b/backend/public/src/main.rs @@ -6,6 +6,7 @@ use std::sync::Arc; use std::time::Duration; use tokio::net::TcpListener; use tokio::signal; +use tower_governor::{governor::GovernorConfigBuilder, GovernorLayer}; use tower_http::trace::{self, TraceLayer}; use tracing_subscriber::{filter, layer::SubscriberExt, prelude::*, util::SubscriberInitExt}; @@ -54,18 +55,25 @@ async fn main() { ) .init(); - // tracing_subscriber::registry() - // .with( - // tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { - // format!( - // "{}=debug,tower_http=debug,axum=trace", - // env!("CARGO_CRATE_NAME") - // ) - // .into() - // }), - // ) - // .with(tracing_subscriber::fmt::layer()) - // .init(); + if std::env::var("RUST_ENV").expect("development") != "development" { + println!("we're not in development, starting up the rate limiter"); + let governor_conf = Arc::new( + GovernorConfigBuilder::default() + .per_second(2) + .burst_size(5) + .finish() + .unwrap(), + ); + + let governor_limiter = governor_conf.limiter().clone(); + let interval = Duration::from_secs(60); + // a separate background task to clean up + std::thread::spawn(move || loop { + std::thread::sleep(interval); + tracing::info!("rate limiting storage size: {}", governor_limiter.len()); + governor_limiter.retain_recent(); + }); + } // grabbing the database url from our env variables let db_connection_str = std::env::var("DATABASE_URL")