scufflecloud_geo_ip/
middleware.rs

1use std::future::{Ready, ready};
2use std::marker::PhantomData;
3use std::net::{IpAddr, SocketAddr};
4use std::sync::Arc;
5
6use axum::extract::ConnectInfo;
7use axum::http::{self, HeaderValue, Request, StatusCode};
8use axum::response::IntoResponse;
9use futures::TryFutureExt;
10use futures::future::{Either, MapOk};
11use maxminddb::MaxMindDbError;
12use tower_layer::Layer;
13use tower_service::Service;
14
15use crate::GeoIpInterface;
16
/// Client IP address that [`GeoIpService`] stores in the request extensions.
#[derive(Clone, Copy, Debug)]
pub struct IpAddressInfo {
    /// The client's IP address. The middleware in this module always stores it in
    /// canonical form (see [`IpAddr::to_canonical`]).
    pub ip_address: IpAddr,
}
21
22impl IpAddressInfo {
23    pub fn to_network(self) -> ipnetwork::IpNetwork {
24        self.ip_address.into()
25    }
26
27    pub fn lookup_geoip_info<'a, T: serde::Deserialize<'a>>(
28        &self,
29        global: &'a impl crate::GeoIpInterface,
30    ) -> Result<Option<T>, MaxMindDbError> {
31        global.geo_ip_resolver().lookup::<T>(self.ip_address)
32    }
33}
34
35#[derive(thiserror::Error, Debug)]
36enum ParseIpHeaderError {
37    #[error("header value is not valid ASCII: {0}")]
38    ValueNotAscii(#[from] http::header::ToStrError),
39    #[error("header contains invalid IP address: {0}")]
40    InvalidIp(#[from] std::net::AddrParseError),
41}
42
43fn parse_ip_header(value: &HeaderValue) -> Result<Vec<IpAddr>, ParseIpHeaderError> {
44    let s = value.to_str()?;
45    let ips = s
46        .split(',')
47        .map(|part| part.trim().parse::<IpAddr>().map(|ip| ip.to_canonical()))
48        .collect::<Result<Vec<_>, _>>()?;
49    Ok(ips)
50}
51
52pub fn middleware<G>() -> GeoIpLayer<G> {
53    GeoIpLayer { _marker: PhantomData }
54}
55
/// Tower layer that wraps an inner service in a [`GeoIpService`].
///
/// `G` is only a type-level marker for the global state type; no `G` value is stored.
pub struct GeoIpLayer<G> {
    _marker: PhantomData<G>,
}
59
60impl<G> Clone for GeoIpLayer<G> {
61    fn clone(&self) -> Self {
62        Self { _marker: self._marker }
63    }
64}
65
/// Service produced by [`GeoIpLayer`]. It resolves the client IP address of each request
/// before delegating to `inner`.
pub struct GeoIpService<G, S> {
    /// The wrapped service.
    inner: S,
    /// Type-level marker for the global state type `G`; nothing is stored.
    _marker: PhantomData<G>,
}
70
71impl<G, S: Clone> Clone for GeoIpService<G, S> {
72    fn clone(&self) -> Self {
73        GeoIpService {
74            inner: self.inner.clone(),
75            _marker: self._marker,
76        }
77    }
78}
79
80impl<S, G> Layer<S> for GeoIpLayer<G> {
81    type Service = GeoIpService<G, S>;
82
83    fn layer(&self, inner: S) -> Self::Service {
84        GeoIpService {
85            inner,
86            _marker: self._marker,
87        }
88    }
89}
90
/// Evaluates a `Result`. Yields the `Ok` value, or returns early from the enclosing
/// `Service::call` by passing the `Err` value to `ret_err!`.
macro_rules! try_ret {
    ($result:expr) => {
        match $result {
            Err(error) => ret_err!(error),
            Ok(value) => value,
        }
    };
}
99
/// Returns early from the enclosing `Service::call`. `$err` is converted with
/// `into_response()` and wrapped in an already-completed future (`Either::Right`).
macro_rules! ret_err {
    ($err:expr) => {{
        let response = $err.into_response();
        return Either::Right(ready(Ok(response)));
    }};
}
105
106impl<S, B, G> Service<Request<B>> for GeoIpService<G, S>
107where
108    S: Service<Request<B>>,
109    S::Response: IntoResponse,
110    G: GeoIpInterface + Send + Sync + 'static,
111{
112    type Error = S::Error;
113    type Future = Either<MapOk<S::Future, fn(S::Response) -> Self::Response>, Ready<Result<Self::Response, S::Error>>>;
114    type Response = axum::response::Response;
115
116    fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), Self::Error>> {
117        self.inner.poll_ready(cx)
118    }
119
120    fn call(&mut self, mut req: Request<B>) -> Self::Future {
121        let Some(info) = req.extensions().get::<ConnectInfo<SocketAddr>>() else {
122            tracing::error!("failed to get connection info");
123            ret_err!(StatusCode::INTERNAL_SERVER_ERROR);
124        };
125
126        let mut ip_address = info.ip().to_canonical();
127        let Some(global) = req.extensions().get::<Arc<G>>() else {
128            tracing::error!("request missing global");
129            ret_err!(StatusCode::INTERNAL_SERVER_ERROR);
130        };
131
132        if let Some(reverse_proxy_config) = global.reverse_proxy_config()
133            && !reverse_proxy_config
134                .internal_networks
135                .iter()
136                .any(|net| net.contains(ip_address))
137        {
138            if !reverse_proxy_config
139                .trusted_proxies
140                .iter()
141                .any(|net| net.contains(ip_address))
142            {
143                tracing::error!(ip = %ip_address, "untrusted ip address, connecting ip not in trusted proxies");
144                ret_err!(StatusCode::BAD_REQUEST);
145            }
146
147            let ip_header = try_ret!(req.headers().get(reverse_proxy_config.ip_header.as_ref()).ok_or_else(|| {
148                tracing::error!(headers = ?req.headers(), header = reverse_proxy_config.ip_header.as_ref(), "missing IP header");
149                StatusCode::BAD_REQUEST
150            }));
151            let ips = try_ret!(parse_ip_header(ip_header).map_err(|e| {
152                tracing::error!(err = %e, header = reverse_proxy_config.ip_header.as_ref(), "invalid IP header");
153                StatusCode::BAD_REQUEST
154            }));
155
156            for ip in ips.iter().rev() {
157                if !reverse_proxy_config.trusted_proxies.iter().any(|net| net.contains(*ip)) {
158                    // Found the client IP
159                    ip_address = *ip;
160                    break;
161                }
162            }
163
164            req.extensions_mut().insert(IpAddressInfo { ip_address });
165        }
166
167        Either::Left(self.inner.call(req).map_ok(|resp| resp.into_response()))
168    }
169}