nautilus_infrastructure/sql/
pg.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Posei Systems Pty Ltd. All rights reserved.
3//  https://poseitrader.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use derive_builder::Builder;
17use sqlx::{ConnectOptions, PgPool, postgres::PgConnectOptions};
18
19#[derive(Debug, Clone, Builder)]
20#[builder(default)]
21pub struct PostgresConnectOptions {
22    pub host: String,
23    pub port: u16,
24    pub username: String,
25    pub password: String,
26    pub database: String,
27}
28
29impl PostgresConnectOptions {
30    /// Creates a new [`PostgresConnectOptions`] instance.
31    #[must_use]
32    pub const fn new(
33        host: String,
34        port: u16,
35        username: String,
36        password: String,
37        database: String,
38    ) -> Self {
39        Self {
40            host,
41            port,
42            username,
43            password,
44            database,
45        }
46    }
47
48    #[must_use]
49    pub fn connection_string(&self) -> String {
50        format!(
51            "postgres://{username}:{password}@{host}:{port}/{database}",
52            username = self.username,
53            password = self.password,
54            host = self.host,
55            port = self.port,
56            database = self.database
57        )
58    }
59
60    #[must_use]
61    pub fn default_administrator() -> Self {
62        Self::new(
63            String::from("localhost"),
64            5432,
65            String::from("postgres"),
66            String::from("pass"),
67            String::from("nautilus"),
68        )
69    }
70}
71
72impl Default for PostgresConnectOptions {
73    fn default() -> Self {
74        Self::new(
75            String::from("localhost"),
76            5432,
77            String::from("nautilus"),
78            String::from("pass"),
79            String::from("nautilus"),
80        )
81    }
82}
83
84impl From<PostgresConnectOptions> for PgConnectOptions {
85    fn from(opt: PostgresConnectOptions) -> Self {
86        Self::new()
87            .host(opt.host.as_str())
88            .port(opt.port)
89            .username(opt.username.as_str())
90            .password(opt.password.as_str())
91            .database(opt.database.as_str())
92            .disable_statement_logging()
93    }
94}
95
96/// Constructs `PostgresConnectOptions` by merging provided arguments, environment variables, and defaults.
97///
98/// # Panics
99///
100/// Panics if an environment variable for port cannot be parsed into a `u16`.
101#[must_use]
102pub fn get_postgres_connect_options(
103    host: Option<String>,
104    port: Option<u16>,
105    username: Option<String>,
106    password: Option<String>,
107    database: Option<String>,
108) -> PostgresConnectOptions {
109    let defaults = PostgresConnectOptions::default_administrator();
110    let host = host
111        .or_else(|| std::env::var("POSTGRES_HOST").ok())
112        .unwrap_or(defaults.host);
113    let port = port
114        .or_else(|| {
115            std::env::var("POSTGRES_PORT")
116                .map(|port| port.parse::<u16>().unwrap())
117                .ok()
118        })
119        .unwrap_or(defaults.port);
120    let username = username
121        .or_else(|| std::env::var("POSTGRES_USERNAME").ok())
122        .unwrap_or(defaults.username);
123    let database = database
124        .or_else(|| std::env::var("POSTGRES_DATABASE").ok())
125        .unwrap_or(defaults.database);
126    let password = password
127        .or_else(|| std::env::var("POSTGRES_PASSWORD").ok())
128        .unwrap_or(defaults.password);
129    PostgresConnectOptions::new(host, port, username, password, database)
130}
131
132/// Connects to a Postgres database with the provided connection `options` returning a connection pool.
133///
134/// # Errors
135///
136/// Returns an error if establishing the database connection fails.
137pub async fn connect_pg(options: PgConnectOptions) -> anyhow::Result<PgPool> {
138    Ok(PgPool::connect_with(options).await?)
139}
140
141/// Scans the current working directory for the `posei_trader` repository
142/// and constructs the path to the SQL schema directory.
143///
144/// # Errors
145///
146/// Returns an error if the `SCHEMA_DIR` environment variable is not set and the repository
147/// cannot be located in the current directory path.
148///
149/// # Panics
150///
151/// Panics if the current working directory cannot be determined or contains invalid UTF-8.
152fn get_schema_dir() -> anyhow::Result<String> {
153    std::env::var("SCHEMA_DIR").or_else(|_| {
154        let nautilus_git_repo_name = "posei_trader";
155        let binding = std::env::current_dir().unwrap();
156        let current_dir = binding.to_str().unwrap();
157        match current_dir.find(nautilus_git_repo_name){
158            Some(index) => {
159                let schema_path = current_dir[0..index + nautilus_git_repo_name.len()].to_string() + "/schema/sql";
160                Ok(schema_path)
161            }
162            None => anyhow::bail!("Could not calculate schema dir from current directory path or SCHEMA_DIR env variable")
163        }
164    })
165}
166
167/// Initializes the Postgres database by creating schema, roles, and executing SQL files from `schema_dir`.
168///
169/// # Errors
170///
171/// Returns an error if any SQL execution or file system operation fails.
172///
173/// # Panics
174///
175/// Panics if `schema_dir` is missing and cannot be determined or if other unwraps fail.
176pub async fn init_postgres(
177    pg: &PgPool,
178    database: String,
179    password: String,
180    schema_dir: Option<String>,
181) -> anyhow::Result<()> {
182    log::info!("Initializing Postgres database with target permissions and schema");
183
184    // Create public schema
185    match sqlx::query("CREATE SCHEMA IF NOT EXISTS public;")
186        .execute(pg)
187        .await
188    {
189        Ok(_) => log::info!("Schema public created successfully"),
190        Err(e) => log::error!("Error creating schema public: {e:?}"),
191    }
192
193    // Create role if not exists
194    match sqlx::query(format!("CREATE ROLE {database} PASSWORD '{password}' LOGIN;").as_str())
195        .execute(pg)
196        .await
197    {
198        Ok(_) => log::info!("Role {database} created successfully"),
199        Err(e) => {
200            if e.to_string().contains("already exists") {
201                log::info!("Role {database} already exists");
202            } else {
203                log::error!("Error creating role {database}: {e:?}");
204            }
205        }
206    }
207
208    // Execute all the sql files in schema dir
209    let schema_dir = schema_dir.unwrap_or_else(|| get_schema_dir().unwrap());
210    let mut sql_files =
211        std::fs::read_dir(schema_dir)?.collect::<Result<Vec<_>, std::io::Error>>()?;
212    for file in &mut sql_files {
213        let file_name = file.file_name();
214        log::info!("Executing schema file: {file_name:?}");
215        let file_path = file.path();
216        let sql_content = std::fs::read_to_string(file_path.clone())?;
217        // if filename is functions.sql, split by plpgsql; if not then by ;
218        let delimiter = match file_name.to_str() {
219            Some("functions.sql") => "$$ LANGUAGE plpgsql;",
220            _ => ";",
221        };
222        let sql_statements = sql_content
223            .split(delimiter)
224            .filter(|s| !s.trim().is_empty())
225            .map(|s| format!("{s}{delimiter}"));
226
227        for sql_statement in sql_statements {
228            sqlx::query(&sql_statement)
229                .execute(pg)
230                .await
231                .map_err(|e| {
232                    if e.to_string().contains("already exists") {
233                        log::info!("Already exists error on statement, skipping");
234                    } else {
235                        panic!("Error executing statement {sql_statement} with error: {e:?}")
236                    }
237                })
238                .unwrap();
239        }
240    }
241
242    // Grant connect
243    match sqlx::query(format!("GRANT CONNECT ON DATABASE {database} TO {database};").as_str())
244        .execute(pg)
245        .await
246    {
247        Ok(_) => log::info!("Connect privileges granted to role {database}"),
248        Err(e) => log::error!("Error granting connect privileges to role {database}: {e:?}"),
249    }
250
251    // Grant all schema privileges to the role
252    match sqlx::query(format!("GRANT ALL PRIVILEGES ON SCHEMA public TO {database};").as_str())
253        .execute(pg)
254        .await
255    {
256        Ok(_) => log::info!("All schema privileges granted to role {database}"),
257        Err(e) => log::error!("Error granting all privileges to role {database}: {e:?}"),
258    }
259
260    // Grant all table privileges to the role
261    match sqlx::query(
262        format!("GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO {database};").as_str(),
263    )
264    .execute(pg)
265    .await
266    {
267        Ok(_) => log::info!("All tables privileges granted to role {database}"),
268        Err(e) => log::error!("Error granting all privileges to role {database}: {e:?}"),
269    }
270
271    // Grant all sequence privileges to the role
272    match sqlx::query(
273        format!("GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public TO {database};").as_str(),
274    )
275    .execute(pg)
276    .await
277    {
278        Ok(_) => log::info!("All sequences privileges granted to role {database}"),
279        Err(e) => log::error!("Error granting all privileges to role {database}: {e:?}"),
280    }
281
282    // Grant all function privileges to the role
283    match sqlx::query(
284        format!("GRANT EXECUTE ON ALL FUNCTIONS IN SCHEMA public TO {database};").as_str(),
285    )
286    .execute(pg)
287    .await
288    {
289        Ok(_) => log::info!("All functions privileges granted to role {database}"),
290        Err(e) => log::error!("Error granting all privileges to role {database}: {e:?}"),
291    }
292
293    Ok(())
294}
295
296/// Drops the Postgres database with the given name using the provided connection pool.
297///
298/// # Errors
299///
300/// Returns an error if the DROP DATABASE command fails.
301pub async fn drop_postgres(pg: &PgPool, database: String) -> anyhow::Result<()> {
302    // Execute drop owned
303    match sqlx::query(format!("DROP OWNED BY {database}").as_str())
304        .execute(pg)
305        .await
306    {
307        Ok(_) => log::info!("Dropped owned objects by role {database}"),
308        Err(e) => log::error!("Error dropping owned by role {database}: {e:?}"),
309    }
310
311    // Revoke connect
312    match sqlx::query(format!("REVOKE CONNECT ON DATABASE {database} FROM {database};").as_str())
313        .execute(pg)
314        .await
315    {
316        Ok(_) => log::info!("Revoked connect privileges from role {database}"),
317        Err(e) => log::error!("Error revoking connect privileges from role {database}: {e:?}"),
318    }
319
320    // Revoke privileges
321    match sqlx::query(
322        format!("REVOKE ALL PRIVILEGES ON DATABASE {database} FROM {database};").as_str(),
323    )
324    .execute(pg)
325    .await
326    {
327        Ok(_) => log::info!("Revoked all privileges from role {database}"),
328        Err(e) => log::error!("Error revoking all privileges from role {database}: {e:?}"),
329    }
330
331    // Execute drop schema
332    match sqlx::query("DROP SCHEMA IF EXISTS public CASCADE")
333        .execute(pg)
334        .await
335    {
336        Ok(_) => log::info!("Dropped schema public"),
337        Err(e) => log::error!("Error dropping schema public: {e:?}"),
338    }
339
340    // Drop role
341    match sqlx::query(format!("DROP ROLE IF EXISTS {database};").as_str())
342        .execute(pg)
343        .await
344    {
345        Ok(_) => log::info!("Dropped role {database}"),
346        Err(e) => log::error!("Error dropping role {database}: {e:?}"),
347    }
348    Ok(())
349}