Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 63 additions & 45 deletions payjoin-cli/src/app/v2/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::fmt;
use std::sync::{Arc, Mutex};
use std::sync::Arc;

use anyhow::{anyhow, Context, Result};
use payjoin::bitcoin::consensus::encode::serialize_hex;
Expand Down Expand Up @@ -40,7 +40,6 @@ pub(crate) struct App {
db: Arc<Database>,
wallet: BitcoindWallet,
interrupt: watch::Receiver<()>,
relay_manager: Arc<Mutex<RelayManager>>,
}

trait StatusText {
Expand Down Expand Up @@ -140,11 +139,10 @@ impl<Status: StatusText> fmt::Display for SessionHistoryRow<Status> {
impl AppTrait for App {
async fn new(config: Config) -> Result<Self> {
let db = Arc::new(Database::create(&config.db_path)?);
let relay_manager = Arc::new(Mutex::new(RelayManager::new()));
let (interrupt_tx, interrupt_rx) = watch::channel(());
tokio::spawn(handle_interrupt(interrupt_tx));
let wallet = BitcoindWallet::new(&config.bitcoind).await?;
let app = Self { config, db, wallet, interrupt: interrupt_rx, relay_manager };
let app = Self { config, db, wallet, interrupt: interrupt_rx };
app.wallet()
.network()
.context("Failed to connect to bitcoind. Check config RPC connection.")?;
Expand All @@ -153,7 +151,6 @@ impl AppTrait for App {

fn wallet(&self) -> BitcoindWallet { self.wallet.clone() }

#[allow(clippy::incompatible_msrv)]
async fn send_payjoin(&self, bip21: &str, fee_rate: FeeRate) -> Result<()> {
use payjoin::UriExt;
let uri = Uri::try_from(bip21)
Expand Down Expand Up @@ -254,10 +251,10 @@ impl AppTrait for App {

async fn receive_payjoin(&self, amount: Amount) -> Result<()> {
let address = self.wallet().get_new_address()?;
let ohttp_keys =
unwrap_ohttp_keys_or_else_fetch(&self.config, None, self.relay_manager.clone())
.await?
.ohttp_keys;
let mut relay_manager = RelayManager::new();
let ohttp_keys = unwrap_ohttp_keys_or_else_fetch(&self.config, None, &mut relay_manager)
.await?
.ohttp_keys;
let persister = ReceiverPersister::new(self.db.clone())?;
let session =
ReceiverBuilder::new(address, self.config.v2()?.pj_directory.as_str(), ohttp_keys)?
Expand All @@ -276,7 +273,6 @@ impl AppTrait for App {
Ok(())
}

#[allow(clippy::incompatible_msrv)]
async fn resume_payjoins(&self) -> Result<()> {
let recv_session_ids = self.db.get_recv_session_ids()?;
let send_session_ids = self.db.get_send_session_ids()?;
Expand Down Expand Up @@ -480,11 +476,12 @@ impl App {
session: SendSession,
persister: &SenderPersister,
) -> Result<()> {
let mut relay_manager = RelayManager::new();
match session {
SendSession::WithReplyKey(context) =>
self.post_original_proposal(context, persister).await?,
self.post_original_proposal(context, persister, &mut relay_manager).await?,
SendSession::PollingForProposal(context) =>
self.get_proposed_payjoin_psbt(context, persister).await?,
self.get_proposed_payjoin_psbt(context, persister, &mut relay_manager).await?,
SendSession::Closed(SenderSessionOutcome::Success(proposal)) => {
self.process_pj_response(proposal)?;
return Ok(());
Expand All @@ -498,22 +495,27 @@ impl App {
&self,
sender: Sender<WithReplyKey>,
persister: &SenderPersister,
relay_manager: &mut RelayManager,
) -> Result<()> {
let (req, ctx) = sender.create_v2_post_request(
self.unwrap_relay_or_else_fetch(Some(&sender.endpoint())).await?.as_str(),
self.unwrap_relay_or_else_fetch(Some(&sender.endpoint()), relay_manager)
.await?
.as_str(),
)?;
let response = self.post_request(req).await?;
println!("Posted original proposal...");
let sender = sender.process_response(&response.bytes().await?, ctx).save(persister)?;
self.get_proposed_payjoin_psbt(sender, persister).await
self.get_proposed_payjoin_psbt(sender, persister, relay_manager).await
}

async fn get_proposed_payjoin_psbt(
&self,
sender: Sender<PollingForProposal>,
persister: &SenderPersister,
relay_manager: &mut RelayManager,
) -> Result<()> {
let ohttp_relay = self.unwrap_relay_or_else_fetch(Some(&sender.endpoint())).await?;
let ohttp_relay =
self.unwrap_relay_or_else_fetch(Some(&sender.endpoint()), relay_manager).await?;
let mut session = sender.clone();
// Long poll until we get a response
loop {
Expand Down Expand Up @@ -544,9 +546,11 @@ impl App {
&self,
session: Receiver<Initialized>,
persister: &ReceiverPersister,
relay_manager: &mut RelayManager,
) -> Result<Receiver<UncheckedOriginalPayload>> {
let ohttp_relay =
self.unwrap_relay_or_else_fetch(Some(&session.pj_uri().extras.endpoint())).await?;
let ohttp_relay = self
.unwrap_relay_or_else_fetch(Some(&session.pj_uri().extras.endpoint()), relay_manager)
.await?;

let mut session = session;
loop {
Expand Down Expand Up @@ -575,30 +579,31 @@ impl App {
session: ReceiveSession,
persister: &ReceiverPersister,
) -> Result<()> {
let mut relay_manager = RelayManager::new();
let res = {
match session {
ReceiveSession::Initialized(proposal) =>
self.read_from_directory(proposal, persister).await,
self.read_from_directory(proposal, persister, &mut relay_manager).await,
ReceiveSession::UncheckedOriginalPayload(proposal) =>
self.check_proposal(proposal, persister).await,
self.check_proposal(proposal, persister, &mut relay_manager).await,
ReceiveSession::MaybeInputsOwned(proposal) =>
self.check_inputs_not_owned(proposal, persister).await,
self.check_inputs_not_owned(proposal, persister, &mut relay_manager).await,
ReceiveSession::MaybeInputsSeen(proposal) =>
self.check_no_inputs_seen_before(proposal, persister).await,
self.check_no_inputs_seen_before(proposal, persister, &mut relay_manager).await,
ReceiveSession::OutputsUnknown(proposal) =>
self.identify_receiver_outputs(proposal, persister).await,
self.identify_receiver_outputs(proposal, persister, &mut relay_manager).await,
ReceiveSession::WantsOutputs(proposal) =>
self.commit_outputs(proposal, persister).await,
self.commit_outputs(proposal, persister, &mut relay_manager).await,
ReceiveSession::WantsInputs(proposal) =>
self.contribute_inputs(proposal, persister).await,
self.contribute_inputs(proposal, persister, &mut relay_manager).await,
ReceiveSession::WantsFeeRange(proposal) =>
self.apply_fee_range(proposal, persister).await,
self.apply_fee_range(proposal, persister, &mut relay_manager).await,
ReceiveSession::ProvisionalProposal(proposal) =>
self.finalize_proposal(proposal, persister).await,
self.finalize_proposal(proposal, persister, &mut relay_manager).await,
ReceiveSession::PayjoinProposal(proposal) =>
self.send_payjoin_proposal(proposal, persister).await,
self.send_payjoin_proposal(proposal, persister, &mut relay_manager).await,
ReceiveSession::HasReplyableError(error) =>
self.handle_error(error, persister).await,
self.handle_error(error, persister, &mut relay_manager).await,
ReceiveSession::Monitor(proposal) =>
self.monitor_payjoin_proposal(proposal, persister).await,
ReceiveSession::Closed(_) => return Err(anyhow!("Session closed")),
Expand All @@ -612,22 +617,24 @@ impl App {
&self,
session: Receiver<Initialized>,
persister: &ReceiverPersister,
relay_manager: &mut RelayManager,
) -> Result<()> {
let mut interrupt = self.interrupt.clone();
let receiver = tokio::select! {
res = self.long_poll_fallback(session, persister) => res,
res = self.long_poll_fallback(session, persister, relay_manager) => res,
_ = interrupt.changed() => {
println!("Interrupted. Call the `resume` command to resume all sessions.");
return Err(anyhow!("Interrupted"));
}
}?;
self.check_proposal(receiver, persister).await
self.check_proposal(receiver, persister, relay_manager).await
}

async fn check_proposal(
&self,
proposal: Receiver<UncheckedOriginalPayload>,
persister: &ReceiverPersister,
relay_manager: &mut RelayManager,
) -> Result<()> {
let wallet = self.wallet();
let proposal = proposal
Expand All @@ -640,13 +647,14 @@ impl App {

println!("Fallback transaction received. Consider broadcasting this to get paid if the Payjoin fails:");
println!("{}", serialize_hex(&proposal.extract_tx_to_schedule_broadcast()));
self.check_inputs_not_owned(proposal, persister).await
self.check_inputs_not_owned(proposal, persister, relay_manager).await
}

async fn check_inputs_not_owned(
&self,
proposal: Receiver<MaybeInputsOwned>,
persister: &ReceiverPersister,
relay_manager: &mut RelayManager,
) -> Result<()> {
let wallet = self.wallet();
let proposal = proposal
Expand All @@ -656,26 +664,28 @@ impl App {
.map_err(|e| ImplementationError::from(e.into_boxed_dyn_error()))
})
.save(persister)?;
self.check_no_inputs_seen_before(proposal, persister).await
self.check_no_inputs_seen_before(proposal, persister, relay_manager).await
}

async fn check_no_inputs_seen_before(
&self,
proposal: Receiver<MaybeInputsSeen>,
persister: &ReceiverPersister,
relay_manager: &mut RelayManager,
) -> Result<()> {
let proposal = proposal
.check_no_inputs_seen_before(&mut |input| {
Ok(self.db.insert_input_seen_before(*input)?)
})
.save(persister)?;
self.identify_receiver_outputs(proposal, persister).await
self.identify_receiver_outputs(proposal, persister, relay_manager).await
}

async fn identify_receiver_outputs(
&self,
proposal: Receiver<OutputsUnknown>,
persister: &ReceiverPersister,
relay_manager: &mut RelayManager,
) -> Result<()> {
let wallet = self.wallet();
let proposal = proposal
Expand All @@ -685,22 +695,24 @@ impl App {
.map_err(|e| ImplementationError::from(e.into_boxed_dyn_error()))
})
.save(persister)?;
self.commit_outputs(proposal, persister).await
self.commit_outputs(proposal, persister, relay_manager).await
}

async fn commit_outputs(
&self,
proposal: Receiver<WantsOutputs>,
persister: &ReceiverPersister,
relay_manager: &mut RelayManager,
) -> Result<()> {
let proposal = proposal.commit_outputs().save(persister)?;
self.contribute_inputs(proposal, persister).await
self.contribute_inputs(proposal, persister, relay_manager).await
}

async fn contribute_inputs(
&self,
proposal: Receiver<WantsInputs>,
persister: &ReceiverPersister,
relay_manager: &mut RelayManager,
) -> Result<()> {
let wallet = self.wallet();
let candidate_inputs = wallet.list_unspent()?;
Expand All @@ -714,22 +726,24 @@ impl App {
let selected_input = proposal.try_preserving_privacy(candidate_inputs)?;
let proposal =
proposal.contribute_inputs(vec![selected_input])?.commit_inputs().save(persister)?;
self.apply_fee_range(proposal, persister).await
self.apply_fee_range(proposal, persister, relay_manager).await
}

async fn apply_fee_range(
&self,
proposal: Receiver<WantsFeeRange>,
persister: &ReceiverPersister,
relay_manager: &mut RelayManager,
) -> Result<()> {
let proposal = proposal.apply_fee_range(None, self.config.max_fee_rate).save(persister)?;
self.finalize_proposal(proposal, persister).await
self.finalize_proposal(proposal, persister, relay_manager).await
}

async fn finalize_proposal(
&self,
proposal: Receiver<ProvisionalProposal>,
persister: &ReceiverPersister,
relay_manager: &mut RelayManager,
) -> Result<()> {
let wallet = self.wallet();
let proposal = proposal
Expand All @@ -739,16 +753,19 @@ impl App {
.map_err(|e| ImplementationError::from(e.into_boxed_dyn_error()))
})
.save(persister)?;
self.send_payjoin_proposal(proposal, persister).await
self.send_payjoin_proposal(proposal, persister, relay_manager).await
}

async fn send_payjoin_proposal(
&self,
proposal: Receiver<PayjoinProposal>,
persister: &ReceiverPersister,
relay_manager: &mut RelayManager,
) -> Result<()> {
let (req, ohttp_ctx) = proposal
.create_post_request(self.unwrap_relay_or_else_fetch(None::<&str>).await?.as_str())
.create_post_request(
self.unwrap_relay_or_else_fetch(None::<&str>, relay_manager).await?.as_str(),
)
.map_err(|e| anyhow!("v2 req extraction failed {}", e))?;
let res = self.post_request(req).await?;
let payjoin_psbt = proposal.psbt().clone();
Expand Down Expand Up @@ -813,14 +830,13 @@ impl App {
async fn unwrap_relay_or_else_fetch(
&self,
directory: Option<impl payjoin::IntoUrl>,
relay_manager: &mut RelayManager,
) -> Result<url::Url> {
let directory = directory.map(|url| url.into_url()).transpose()?;
let selected_relay =
self.relay_manager.lock().expect("Lock should not be poisoned").get_selected_relay();
let ohttp_relay = match selected_relay {
let ohttp_relay = match relay_manager.get_selected_relay() {
Some(relay) => relay,
None =>
unwrap_ohttp_keys_or_else_fetch(&self.config, directory, self.relay_manager.clone())
unwrap_ohttp_keys_or_else_fetch(&self.config, directory, relay_manager)
.await?
.relay_url,
};
Expand All @@ -832,9 +848,11 @@ impl App {
&self,
session: Receiver<HasReplyableError>,
persister: &ReceiverPersister,
relay_manager: &mut RelayManager,
) -> Result<()> {
let (err_req, err_ctx) = session
.create_error_request(self.unwrap_relay_or_else_fetch(None::<&str>).await?.as_str())?;
let (err_req, err_ctx) = session.create_error_request(
self.unwrap_relay_or_else_fetch(None::<&str>, relay_manager).await?.as_str(),
)?;

let err_response = match self.post_request(err_req).await {
Ok(response) => response,
Expand Down
Loading
Loading