Skip to content

Commit

Permalink
feat: Pass payload while building SubMsg (#441)
Browse files Browse the repository at this point in the history
  • Loading branch information
jawoznia committed Oct 14, 2024
1 parent cdbbc57 commit df259f2
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 31 deletions.
105 changes: 83 additions & 22 deletions sylvia-derive/src/contract/communication/reply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,10 @@ impl<'a> Reply<'a> {

let methods_declaration = reply_data.iter().map(|data| {
let method_name = &data.handler_id;
let payload_parameters = data.as_payload_parameters();

quote! {
fn #method_name (self) -> #sylvia ::cw_std::SubMsg<CustomMsgT>;
fn #method_name (self, #(#payload_parameters),* ) -> #sylvia ::cw_std::StdResult< #sylvia ::cw_std::SubMsg<CustomMsgT>>;
}
});

Expand Down Expand Up @@ -182,7 +183,7 @@ impl<'a> ReplyVariants<'a> for MsgVariants<'a, GenericParam> {
},
)
}
Some(existing_data) => existing_data.handlers.push(handler),
Some(existing_data) => existing_data.add_second_handler(handler),
None => reply_data.push(ReplyData::new(reply_id, handler, handler_id)),
}
});
Expand Down Expand Up @@ -210,6 +211,41 @@ impl<'a> ReplyData<'a> {
}
}

/// Adds second handler to the reply data.
///
/// # Error
/// This method emits an error if there is already more then one handler.
pub fn add_second_handler(&mut self, new_handler: &'a MsgVariant<'a>) {
let current_handler= match self.handlers.first() {
Some(handler) if self.handlers.len() == 1 => handler,
_ => return,

Check warning on line 221 in sylvia-derive/src/contract/communication/reply.rs

View check run for this annotation

Codecov / codecov/patch

sylvia-derive/src/contract/communication/reply.rs#L221

Added line #L221 was not covered by tests
};

if current_handler.fields().len() != new_handler.fields().len() {
emit_error!(current_handler.function_name().span(), "Mismatched lenght of method parameters.";
note = self.handler_id.span() => format!("Both {} handlers should have the same number of parameters.", self.handler_id)

Check warning on line 226 in sylvia-derive/src/contract/communication/reply.rs

View check run for this annotation

Codecov / codecov/patch

sylvia-derive/src/contract/communication/reply.rs#L224-L226

Added lines #L224 - L226 were not covered by tests
// note = existing_handler.handler_id.span() => format!("Previous definition of {} handler", existing_handler.handler_id)
);
}

current_handler
.fields()
.iter()
.skip(1)
.zip(new_handler.fields().iter().skip(1))
.for_each(|(current_field, new_field)|
{
if current_field!= new_field {
emit_error!(current_field.name().span(), "Mismatched parameter in reply handlers.";
note = current_field.name().span() => format!("Parameters for the {} handler have to be the same.", self.handler_id);
note = new_field.name().span() => format!("Previous parameter defined for the {} handler.", self.handler_id)

Check warning on line 241 in sylvia-derive/src/contract/communication/reply.rs

View check run for this annotation

Codecov / codecov/patch

sylvia-derive/src/contract/communication/reply.rs#L239-L241

Added lines #L239 - L241 were not covered by tests
)
}
});

self.handlers.push(new_handler);

Check warning on line 246 in sylvia-derive/src/contract/communication/reply.rs

View check run for this annotation

Codecov / codecov/patch

sylvia-derive/src/contract/communication/reply.rs#L246

Added line #L246 was not covered by tests
}

/// Emits success and failure match arms for a single `ReplyId`.
fn emit_match_arms(&self, contract: &Type, generics: &[&GenericParam]) -> TokenStream {
let Self {
Expand Down Expand Up @@ -269,14 +305,19 @@ impl<'a> ReplyData<'a> {

let method_name = handler_id;
let reply_on = self.emit_cw_reply_on();
let payload_parameters = self.as_payload_parameters();
let payload_values= self.as_payload_values();

quote! {
fn #method_name (self) -> #sylvia ::cw_std::SubMsg<CustomMsgT> {
#sylvia ::cw_std::SubMsg {
fn #method_name (self, #(#payload_parameters),* ) -> #sylvia ::cw_std::StdResult< #sylvia ::cw_std::SubMsg<CustomMsgT>> {
let payload = #sylvia ::cw_std::to_json_binary(&( #(#payload_values),* ))?;

Ok( #sylvia ::cw_std::SubMsg {
reply_on: #reply_on ,
id: #reply_id ,
payload,
..self
}
})
}
}
}
Expand All @@ -291,19 +332,30 @@ impl<'a> ReplyData<'a> {

let method_name = handler_id;
let reply_on = self.emit_cw_reply_on();
let payload_parameters = self.as_payload_parameters();
let payload_values= self.as_payload_values();

quote! {
fn #method_name (self) -> #sylvia ::cw_std::SubMsg<CustomMsgT> {
#sylvia ::cw_std::SubMsg {
fn #method_name (self, #(#payload_parameters),* ) -> #sylvia ::cw_std::StdResult< #sylvia ::cw_std::SubMsg<CustomMsgT>> {
let payload = #sylvia ::cw_std::to_json_binary(&( #(#payload_values),* ))?;
Ok( #sylvia ::cw_std::SubMsg {
reply_on: #reply_on ,
id: #reply_id ,
msg: self.into(),
payload: Default::default(),
payload,
gas_limit: None,
}
})
}
}
}

fn as_payload_parameters(&self) -> impl Iterator<Item = TokenStream> + 'a {
self.handlers.first().unwrap().fields().iter().skip(1).map(MsgField::emit_method_field)
}

fn as_payload_values(&self) -> impl Iterator<Item = &Ident> {
self.handlers.first().unwrap().fields().iter().skip(1).map(MsgField::name)
}
}

/// Emits match arm for [ReplyOn::Success].
Expand All @@ -318,7 +370,7 @@ fn emit_success_match_arm(handlers: &[&MsgVariant], contract_turbofish: &Type) -
}) {
Some(handler) if handler.msg_attr().reply_on() == ReplyOn::Success => {
let function_name = handler.function_name();
let payload = handler.emit_payload_parameters();
let payload_names = handler.emit_payload_names();
let payload_deserialization = handler.emit_payload_deserialization();

quote! {
Expand All @@ -327,20 +379,20 @@ fn emit_success_match_arm(handlers: &[&MsgVariant], contract_turbofish: &Type) -
let #sylvia ::cw_std::SubMsgResponse { events, data, msg_responses} = sub_msg_resp;
#payload_deserialization

#contract_turbofish ::new(). #function_name ((deps, env, gas_used, events, msg_responses).into(), data, #payload )
#contract_turbofish ::new(). #function_name ((deps, env, gas_used, events, msg_responses).into(), data, #payload_names )
}
}
}
Some(handler) if handler.msg_attr().reply_on() == ReplyOn::Always => {
let function_name = handler.function_name();
let payload = handler.emit_payload_parameters();
let payload_names = handler.emit_payload_names();
let payload_deserialization = handler.emit_payload_deserialization();

quote! {
#sylvia ::cw_std::SubMsgResult::Ok(_) => {
#payload_deserialization

#contract_turbofish ::new(). #function_name ((deps, env, gas_used, vec![], vec![]).into(), result, #payload )
#contract_turbofish ::new(). #function_name ((deps, env, gas_used, vec![], vec![]).into(), result, #payload_names )
}
}
}
Expand Down Expand Up @@ -371,27 +423,27 @@ fn emit_failure_match_arm(handlers: &[&MsgVariant], contract_turbofish: &Type) -
}) {
Some(handler) if handler.msg_attr().reply_on() == ReplyOn::Failure => {
let function_name = handler.function_name();
let payload = handler.emit_payload_parameters();
let payload_names = handler.emit_payload_names();
let payload_deserialization = handler.emit_payload_deserialization();

quote! {
#sylvia ::cw_std::SubMsgResult::Err(error) => {
#payload_deserialization

#contract_turbofish ::new(). #function_name ((deps, env, gas_used, vec![], vec![]).into(), error, #payload )
#contract_turbofish ::new(). #function_name ((deps, env, gas_used, vec![], vec![]).into(), error, #payload_names )
}
}
}
Some(handler) if handler.msg_attr().reply_on() == ReplyOn::Always => {
let function_name = handler.function_name();
let payload = handler.emit_payload_parameters();
let payload_names = handler.emit_payload_names();
let payload_deserialization = handler.emit_payload_deserialization();

quote! {
#sylvia ::cw_std::SubMsgResult::Err(_) => {
#payload_deserialization

#contract_turbofish ::new(). #function_name ((deps, env, gas_used, vec![], vec![]).into(), result, #payload )
#contract_turbofish ::new(). #function_name ((deps, env, gas_used, vec![], vec![]).into(), result, #payload_names )
}
}
}
Expand All @@ -405,8 +457,9 @@ fn emit_failure_match_arm(handlers: &[&MsgVariant], contract_turbofish: &Type) -

trait ReplyVariant<'a> {
fn as_variant_handlers_pair(&'a self) -> Vec<(&'a MsgVariant<'a>, &'a Ident)>;
fn emit_payload_parameters(&self) -> TokenStream;
fn emit_payload_deserialization(&self) -> TokenStream;
fn emit_payload_names(&'a self) -> TokenStream;
fn emit_payload_deserialization(&'a self) -> TokenStream;
fn emit_payload_fields(&'a self) -> impl Iterator<Item = TokenStream>;
}

impl<'a> ReplyVariant<'a> for MsgVariant<'a> {
Expand All @@ -425,16 +478,16 @@ impl<'a> ReplyVariant<'a> for MsgVariant<'a> {
variant_handler_id_pair
}

fn emit_payload_parameters(&self) -> TokenStream {
fn emit_payload_names(&self) -> TokenStream {
if self
.fields()
.iter()
.any(|field| field.contains_attribute(SylviaAttribute::Payload))
{
quote! { payload }
} else {
let deserialized_payload = self.fields().iter().skip(1).map(MsgField::name);
quote! { #(#deserialized_payload),* }
let deserialized_payload_names = self.fields().iter().skip(1).map(MsgField::name);
quote! { #(#deserialized_payload_names),* }
}
}

Expand All @@ -454,6 +507,13 @@ impl<'a> ReplyVariant<'a> for MsgVariant<'a> {
let ( #(#deserialized_names),* ) = #sylvia ::cw_std::from_json(&payload)?;
}
}

fn emit_payload_fields(&'a self) -> impl Iterator<Item = TokenStream> {
self.fields()

Check warning on line 512 in sylvia-derive/src/contract/communication/reply.rs

View check run for this annotation

Codecov / codecov/patch

sylvia-derive/src/contract/communication/reply.rs#L511-L512

Added lines #L511 - L512 were not covered by tests
.iter()
.skip(1)
.map(MsgField::emit_method_field)

Check warning on line 515 in sylvia-derive/src/contract/communication/reply.rs

View check run for this annotation

Codecov / codecov/patch

sylvia-derive/src/contract/communication/reply.rs#L515

Added line #L515 was not covered by tests
}
}

/// Maps self to an [Ident] reply id.
Expand All @@ -467,3 +527,4 @@ impl AsReplyId for Ident {
Ident::new(&reply_id, self.span())
}
}

2 changes: 1 addition & 1 deletion sylvia-derive/src/types/msg_field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use syn::visit::Visit;
use syn::{Attribute, Ident, Pat, PatType, Type};

/// Representation of single message variant field
#[derive(Debug)]
#[derive(PartialEq, Debug)]
pub struct MsgField<'a> {
name: &'a Ident,
ty: &'a Type,
Expand Down
17 changes: 9 additions & 8 deletions sylvia/tests/reply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,10 @@ where
let sub_msg = InstantiateBuilder::noop_contract(remote_code_id)?
.with_label("noop")
.build()
.remote_instantiated()
.with_payload(to_json_binary(&payload)?);
.remote_instantiated(to_json_binary(&payload)?)?;
// Blocked by https://github.com/CosmWasm/cw-multi-test/pull/216.
// Payload is not currently forwarded in the MultiTest.
// .remote_instantiated(payload)?;

Ok(Response::new().add_submessage(sub_msg))
}
Expand All @@ -125,7 +127,7 @@ where
.executor()
.noop(should_fail)?
.build()
.success();
.success(Binary::default())?;

Ok(Response::new().add_submessage(msg))
}
Expand All @@ -142,7 +144,7 @@ where
.executor()
.noop(should_fail)?
.build()
.failure();
.failure(Binary::default())?;

Ok(Response::new().add_submessage(msg))
}
Expand All @@ -159,7 +161,7 @@ where
.executor()
.noop(should_fail)?
.build()
.both();
.both(Binary::default())?;

Ok(Response::new().add_submessage(msg))
}
Expand All @@ -179,8 +181,7 @@ where
.executor()
.noop(should_fail)?
.build()
.always()
.with_payload(payload);
.always(payload)?;

Ok(Response::new().add_submessage(msg))
}
Expand Down Expand Up @@ -275,7 +276,7 @@ where
to_address: remote_addr.as_ref().to_string(),
amount: vec![],
});
let submsg = cosmos_msg.always();
let submsg = cosmos_msg.always(Binary::default())?;
Ok(Response::new().add_submessage(submsg))
}
}
Expand Down

0 comments on commit df259f2

Please sign in to comment.