nautilus_testkit/
files.rs1use std::{
17 fs::{File, OpenOptions},
18 io::{BufReader, BufWriter, Read, copy},
19 path::Path,
20};
21
22use reqwest::blocking::Client;
23use ring::digest;
24use serde_json::Value;
25
26pub fn ensure_file_exists_or_download_http(
43 filepath: &Path,
44 url: &str,
45 checksums: Option<&Path>,
46) -> anyhow::Result<()> {
47 if filepath.exists() {
48 println!("File already exists: {filepath:?}");
49
50 if let Some(checksums_file) = checksums {
51 if verify_sha256_checksum(filepath, checksums_file)? {
52 println!("File is valid");
53 return Ok(());
54 } else {
55 let new_checksum = calculate_sha256(filepath)?;
56 println!("Adding checksum for existing file: {new_checksum}");
57 update_sha256_checksums(filepath, checksums_file, &new_checksum)?;
58 return Ok(());
59 }
60 }
61 return Ok(());
62 }
63
64 download_file(filepath, url)?;
65
66 if let Some(checksums_file) = checksums {
67 let new_checksum = calculate_sha256(filepath)?;
68 update_sha256_checksums(filepath, checksums_file, &new_checksum)?;
69 }
70
71 Ok(())
72}
73
74fn download_file(filepath: &Path, url: &str) -> anyhow::Result<()> {
75 println!("Downloading file from {url} to {filepath:?}");
76
77 if let Some(parent) = filepath.parent() {
78 std::fs::create_dir_all(parent)?;
79 }
80
81 let mut response = Client::new().get(url).send()?;
82 if !response.status().is_success() {
83 anyhow::bail!("Failed to download file: HTTP {}", response.status());
84 }
85
86 let mut out = File::create(filepath)?;
87 copy(&mut response, &mut out)?;
88
89 println!("File downloaded to {filepath:?}");
90 Ok(())
91}
92
93fn calculate_sha256(filepath: &Path) -> anyhow::Result<String> {
94 let mut file = File::open(filepath)?;
95 let mut context = digest::Context::new(&digest::SHA256);
96 let mut buffer = [0; 4096];
97
98 loop {
99 let count = file.read(&mut buffer)?;
100 if count == 0 {
101 break;
102 }
103 context.update(&buffer[..count]);
104 }
105
106 let digest = context.finish();
107 Ok(hex::encode(digest.as_ref()))
108}
109
110fn verify_sha256_checksum(filepath: &Path, checksums: &Path) -> anyhow::Result<bool> {
111 let file = File::open(checksums)?;
112 let reader = BufReader::new(file);
113 let checksums: Value = serde_json::from_reader(reader)?;
114
115 let filename = filepath.file_name().unwrap().to_str().unwrap();
116 if let Some(expected_checksum) = checksums.get(filename) {
117 let expected_checksum_str = expected_checksum.as_str().unwrap();
118 let expected_hash = expected_checksum_str
119 .strip_prefix("sha256:")
120 .unwrap_or(expected_checksum_str);
121 let calculated_checksum = calculate_sha256(filepath)?;
122 if expected_hash == calculated_checksum {
123 return Ok(true);
124 }
125 }
126
127 Ok(false)
128}
129
130fn update_sha256_checksums(
131 filepath: &Path,
132 checksums_file: &Path,
133 new_checksum: &str,
134) -> anyhow::Result<()> {
135 let checksums: Value = if checksums_file.exists() {
136 let file = File::open(checksums_file)?;
137 let reader = BufReader::new(file);
138 serde_json::from_reader(reader)?
139 } else {
140 serde_json::json!({})
141 };
142
143 let mut checksums_map = checksums.as_object().unwrap().clone();
144
145 let filename = filepath.file_name().unwrap().to_str().unwrap().to_string();
147 let prefixed_checksum = format!("sha256:{new_checksum}");
148 checksums_map.insert(filename, Value::String(prefixed_checksum));
149
150 let file = OpenOptions::new()
151 .write(true)
152 .create(true)
153 .truncate(true)
154 .open(checksums_file)?;
155 let writer = BufWriter::new(file);
156 serde_json::to_writer_pretty(writer, &serde_json::Value::Object(checksums_map))?;
157
158 Ok(())
159}
160
#[cfg(test)]
mod tests {
    use std::{
        fs,
        io::{BufWriter, Write},
        net::SocketAddr,
        sync::Arc,
    };

    use axum::{Router, http::StatusCode, routing::get, serve};
    use rstest::*;
    use serde_json::{json, to_writer};
    use tempfile::TempDir;
    use tokio::{
        net::TcpListener,
        task,
        time::{Duration, sleep},
    };

    use super::*;

    /// Starts a local HTTP server on an OS-assigned port and returns its address.
    ///
    /// `GET /testfile.txt` answers with `status_code` and either the given `server_content`
    /// or the text `File not found` when no content was supplied.
    async fn setup_test_server(
        server_content: Option<String>,
        status_code: StatusCode,
    ) -> SocketAddr {
        let shared_body = Arc::new(server_content);
        let handler = move || {
            let shared_body = Arc::clone(&shared_body);
            async move {
                let payload = (*shared_body)
                    .clone()
                    .unwrap_or_else(|| "File not found".to_string());
                (status_code, payload)
            }
        };
        let app = Router::new().route("/testfile.txt", get(handler));

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let bound_addr = listener.local_addr().unwrap();

        // Serve in the background for the remainder of the test.
        task::spawn(async move {
            if let Err(e) = serve(listener, app).await {
                eprintln!("server error: {e}");
            }
        });

        // Short pause before handing the address to the caller.
        sleep(Duration::from_millis(100)).await;

        bound_addr
    }

    /// An existing file is left untouched and no download is attempted.
    #[tokio::test]
    async fn test_file_already_exists() {
        let workdir = TempDir::new().unwrap();
        let target = workdir.path().join("testfile.txt");
        fs::write(&target, "Existing file content").unwrap();

        let outcome = ensure_file_exists_or_download_http(
            &target,
            "http://example.com/testfile.txt",
            None,
        );

        assert!(outcome.is_ok());
        assert_eq!(
            fs::read_to_string(&target).unwrap(),
            "Existing file content"
        );
    }

    /// A missing file is fetched from the server and written with the served content.
    #[tokio::test]
    async fn test_download_file_success() {
        let workdir = TempDir::new().unwrap();
        let target = workdir.path().join("testfile.txt");

        let expected_body = "Server file content".to_string();
        let addr = setup_test_server(Some(expected_body.clone()), StatusCode::OK).await;
        let url = format!("http://{addr}/testfile.txt");

        // The download is blocking, so it runs off the async runtime's worker threads.
        let download_target = target.clone();
        let outcome = tokio::task::spawn_blocking(move || {
            ensure_file_exists_or_download_http(&download_target, &url, None)
        })
        .await
        .unwrap();

        assert!(outcome.is_ok());
        assert_eq!(fs::read_to_string(&target).unwrap(), expected_body);
    }

    /// A non-success HTTP status is reported as a download failure.
    #[tokio::test]
    async fn test_download_file_not_found() {
        let workdir = TempDir::new().unwrap();
        let target = workdir.path().join("testfile.txt");

        let addr = setup_test_server(None, StatusCode::NOT_FOUND).await;
        let url = format!("http://{addr}/testfile.txt");

        let outcome = tokio::task::spawn_blocking(move || {
            ensure_file_exists_or_download_http(&target, &url, None)
        })
        .await
        .unwrap();

        assert!(outcome.is_err());
        let err_msg = outcome.unwrap_err().to_string();
        assert!(
            err_msg.contains("Failed to download file"),
            "Unexpected error message: {err_msg}"
        );
    }

    /// A request to an address nothing listens on surfaces as an error.
    #[tokio::test]
    async fn test_network_error() {
        let workdir = TempDir::new().unwrap();
        let target = workdir.path().join("testfile.txt");

        let url = String::from("http://127.0.0.1:0/testfile.txt");

        let outcome = tokio::task::spawn_blocking(move || {
            ensure_file_exists_or_download_http(&target, &url, None)
        })
        .await
        .unwrap();

        assert!(outcome.is_err());
        let err_msg = outcome.unwrap_err().to_string();
        assert!(
            err_msg.contains("error"),
            "Unexpected error message: {err_msg}"
        );
    }

    /// The digest of a known payload matches its published SHA-256 value.
    #[rstest]
    fn test_calculate_sha256() -> anyhow::Result<()> {
        let workdir = TempDir::new()?;
        let sample_path = workdir.path().join("test_file.txt");
        File::create(&sample_path)?.write_all(b"Hello, world!")?;

        let calculated_hash = calculate_sha256(&sample_path)?;

        assert_eq!(
            calculated_hash,
            "315f5bdb76d078c43b8ac0064e4a0164612b1fce77c869345bfc94c75894edd3"
        );
        Ok(())
    }

    /// A checksums file holding the file's own prefixed digest verifies successfully.
    #[rstest]
    fn test_verify_sha256_checksum() -> anyhow::Result<()> {
        let workdir = TempDir::new()?;
        let sample_path = workdir.path().join("test_file.txt");
        File::create(&sample_path)?.write_all(b"Hello, world!")?;

        let digest_hex = calculate_sha256(&sample_path)?;

        let checksums_path = workdir.path().join("checksums.json");
        let checksums_data = json!({
            "test_file.txt": format!("sha256:{}", digest_hex)
        });
        to_writer(
            BufWriter::new(File::create(&checksums_path)?),
            &checksums_data,
        )?;

        let is_valid = verify_sha256_checksum(&sample_path, &checksums_path)?;
        assert!(is_valid, "The checksum should be valid");
        Ok(())
    }
}