scufflecloud_geo_ip/
middleware.rs1use std::future::{Ready, ready};
2use std::marker::PhantomData;
3use std::net::{IpAddr, SocketAddr};
4use std::sync::Arc;
5
6use axum::extract::ConnectInfo;
7use axum::http::{self, HeaderValue, Request, StatusCode};
8use axum::response::IntoResponse;
9use futures::TryFutureExt;
10use futures::future::{Either, MapOk};
11use maxminddb::MaxMindDbError;
12use tower_layer::Layer;
13use tower_service::Service;
14
15use crate::GeoIpInterface;
16
17#[derive(Debug, Clone, Copy)]
18pub struct IpAddressInfo {
19 pub ip_address: IpAddr,
20}
21
22impl IpAddressInfo {
23 pub fn to_network(self) -> ipnetwork::IpNetwork {
24 self.ip_address.into()
25 }
26
27 pub fn lookup_geoip_info<'a, T: serde::Deserialize<'a>>(
28 &self,
29 global: &'a impl crate::GeoIpInterface,
30 ) -> Result<Option<T>, MaxMindDbError> {
31 global.geo_ip_resolver().lookup::<T>(self.ip_address)
32 }
33}
34
35#[derive(thiserror::Error, Debug)]
36enum ParseIpHeaderError {
37 #[error("header value is not valid ASCII: {0}")]
38 ValueNotAscii(#[from] http::header::ToStrError),
39 #[error("header contains invalid IP address: {0}")]
40 InvalidIp(#[from] std::net::AddrParseError),
41}
42
43fn parse_ip_header(value: &HeaderValue) -> Result<Vec<IpAddr>, ParseIpHeaderError> {
44 let s = value.to_str()?;
45 let ips = s
46 .split(',')
47 .map(|part| part.trim().parse::<IpAddr>().map(|ip| ip.to_canonical()))
48 .collect::<Result<Vec<_>, _>>()?;
49 Ok(ips)
50}
51
52pub fn middleware<G>() -> GeoIpLayer<G> {
53 GeoIpLayer { _marker: PhantomData }
54}
55
/// Tower layer that wraps a service in a `GeoIpService`.
///
/// Zero-sized: `G` is only a type-level marker for the global state type.
pub struct GeoIpLayer<G> {
    _marker: PhantomData<G>,
}

// Written by hand because `#[derive(Clone)]` would add a needless `G: Clone` bound.
impl<G> Clone for GeoIpLayer<G> {
    fn clone(&self) -> Self {
        GeoIpLayer { _marker: PhantomData }
    }
}
65
/// Service produced by `GeoIpLayer`: resolves the client IP address for each
/// request before delegating to `inner`.
pub struct GeoIpService<G, S> {
    inner: S,
    _marker: PhantomData<G>,
}

// Manual impl so that cloning only requires `S: Clone`, not `G: Clone`.
impl<G, S: Clone> Clone for GeoIpService<G, S> {
    fn clone(&self) -> Self {
        let inner = S::clone(&self.inner);
        Self { inner, _marker: PhantomData }
    }
}
79
80impl<S, G> Layer<S> for GeoIpLayer<G> {
81 type Service = GeoIpService<G, S>;
82
83 fn layer(&self, inner: S) -> Self::Service {
84 GeoIpService {
85 inner,
86 _marker: self._marker,
87 }
88 }
89}
90
// Unwraps a `Result` inside `Service::call`: yields the `Ok` value, or returns
// early from the enclosing function with the `Err` value rendered as a response
// (via `ret_err!`).
macro_rules! try_ret {
    ($res:expr) => {
        match $res {
            Err(e) => ret_err!(e),
            Ok(value) => value,
        }
    };
}
99
// Returns early from `Service::call` with an already-completed future that
// resolves to `$err` converted into a response (the `Either::Right` arm).
macro_rules! ret_err {
    ($err:expr) => {{
        let response = $err.into_response();
        return Either::Right(ready(Ok(response)));
    }};
}
105
106impl<S, B, G> Service<Request<B>> for GeoIpService<G, S>
107where
108 S: Service<Request<B>>,
109 S::Response: IntoResponse,
110 G: GeoIpInterface + Send + Sync + 'static,
111{
112 type Error = S::Error;
113 type Future = Either<MapOk<S::Future, fn(S::Response) -> Self::Response>, Ready<Result<Self::Response, S::Error>>>;
114 type Response = axum::response::Response;
115
116 fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), Self::Error>> {
117 self.inner.poll_ready(cx)
118 }
119
120 fn call(&mut self, mut req: Request<B>) -> Self::Future {
121 let Some(info) = req.extensions().get::<ConnectInfo<SocketAddr>>() else {
122 tracing::error!("failed to get connection info");
123 ret_err!(StatusCode::INTERNAL_SERVER_ERROR);
124 };
125
126 let mut ip_address = info.ip().to_canonical();
127 let Some(global) = req.extensions().get::<Arc<G>>() else {
128 tracing::error!("request missing global");
129 ret_err!(StatusCode::INTERNAL_SERVER_ERROR);
130 };
131
132 if let Some(reverse_proxy_config) = global.reverse_proxy_config()
133 && !reverse_proxy_config
134 .internal_networks
135 .iter()
136 .any(|net| net.contains(ip_address))
137 {
138 if !reverse_proxy_config
139 .trusted_proxies
140 .iter()
141 .any(|net| net.contains(ip_address))
142 {
143 tracing::error!(ip = %ip_address, "untrusted ip address, connecting ip not in trusted proxies");
144 ret_err!(StatusCode::BAD_REQUEST);
145 }
146
147 let ip_header = try_ret!(req.headers().get(reverse_proxy_config.ip_header.as_ref()).ok_or_else(|| {
148 tracing::error!(headers = ?req.headers(), header = reverse_proxy_config.ip_header.as_ref(), "missing IP header");
149 StatusCode::BAD_REQUEST
150 }));
151 let ips = try_ret!(parse_ip_header(ip_header).map_err(|e| {
152 tracing::error!(err = %e, header = reverse_proxy_config.ip_header.as_ref(), "invalid IP header");
153 StatusCode::BAD_REQUEST
154 }));
155
156 for ip in ips.iter().rev() {
157 if !reverse_proxy_config.trusted_proxies.iter().any(|net| net.contains(*ip)) {
158 ip_address = *ip;
160 break;
161 }
162 }
163
164 req.extensions_mut().insert(IpAddressInfo { ip_address });
165 }
166
167 Either::Left(self.inner.call(req).map_ok(|resp| resp.into_response()))
168 }
169}