Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A variety of improvements around tool parameter modeling (second try) #19027

Open
wants to merge 12 commits into
base: dev
Choose a base branch
from
16 changes: 16 additions & 0 deletions lib/galaxy/tool_util/parameters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
encode,
encode_test,
fill_static_defaults,
landing_decode,
landing_encode,
)
from .factory import (
from_input_source,
Expand Down Expand Up @@ -39,6 +41,7 @@
HiddenParameterModel,
IntegerParameterModel,
LabelValue,
RawStateDict,
RepeatParameterModel,
RulesParameterModel,
SelectParameterModel,
Expand All @@ -49,15 +52,20 @@
ToolParameterT,
validate_against_model,
validate_internal_job,
validate_internal_landing_request,
validate_internal_request,
validate_internal_request_dereferenced,
validate_landing_request,
validate_request,
validate_test_case,
validate_workflow_step,
validate_workflow_step_linked,
ValidationFunctionT,
)
from .state import (
JobInternalToolState,
LandingRequestInternalToolState,
LandingRequestToolState,
RequestInternalDereferencedToolState,
RequestInternalToolState,
RequestToolState,
Expand Down Expand Up @@ -113,10 +121,14 @@
"ConditionalParameterModel",
"ConditionalWhen",
"RepeatParameterModel",
"RawStateDict",
"ValidationFunctionT",
"validate_against_model",
"validate_internal_job",
"validate_internal_landing_request",
"validate_internal_request",
"validate_internal_request_dereferenced",
"validate_landing_request",
"validate_request",
"validate_test_case",
"validate_workflow_step",
Expand All @@ -130,6 +142,8 @@
"RequestToolState",
"RequestInternalToolState",
"RequestInternalDereferencedToolState",
"LandingRequestToolState",
"LandingRequestInternalToolState",
"flat_state_path",
"keys_starting_with",
"visit_input_values",
Expand All @@ -139,6 +153,8 @@
"encode",
"encode_test",
"fill_static_defaults",
"landing_decode",
"landing_encode",
"dereference",
"WorkflowStepToolState",
"WorkflowStepLinkedToolState",
Expand Down
11 changes: 11 additions & 0 deletions lib/galaxy/tool_util/parameters/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""

from typing import (
Any,
cast,
List,
Optional,
Expand All @@ -15,6 +16,7 @@

# https://stackoverflow.com/questions/56832881/check-if-a-field-is-typing-optional
from typing_extensions import (
Annotated,
get_args,
get_origin,
)
Expand Down Expand Up @@ -46,3 +48,12 @@ def cast_as_type(arg) -> Type:

def is_optional(field) -> bool:
return get_origin(field) is Union and type(None) in get_args(field)


def expand_annotation(field: Type, new_annotations: List[Any]) -> Type:
is_annotation = get_origin(field) is Annotated
if is_annotation:
args = get_args(field) # noqa: F841
return Annotated[tuple([args[0], *args[1:], *new_annotations])] # type: ignore[return-value]
else:
return Annotated[tuple([field, *new_annotations])] # type: ignore[return-value]
173 changes: 110 additions & 63 deletions lib/galaxy/tool_util/parameters/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,15 @@
)
from .state import (
JobInternalToolState,
LandingRequestInternalToolState,
LandingRequestToolState,
RequestInternalDereferencedToolState,
RequestInternalToolState,
RequestToolState,
TestCaseToolState,
)
from .visitor import (
Callback,
validate_explicit_conditional_test_value,
visit_input_values,
VISITOR_NO_REPLACEMENT,
Expand All @@ -54,40 +57,22 @@
log = logging.getLogger(__name__)


DecodeFunctionT = Callable[[str], int]
EncodeFunctionT = Callable[[int], str]
DereferenceCallable = Callable[[DataRequestUri], DataRequestInternalHda]
# interfaces for adapting test data dictionaries to tool request dictionaries
# e.g. {class: File, path: foo.bed} => {src: hda, id: ab1235cdfea3}
AdaptDatasets = Callable[[JsonTestDatasetDefDict], DataRequestHda]
AdaptCollections = Callable[[JsonTestCollectionDefDict], DataCollectionRequest]


def decode(
external_state: RequestToolState, input_models: ToolParameterBundle, decode_id: Callable[[str], int]
) -> RequestInternalToolState:
"""Prepare an external representation of tool state (request) for storing in the database (request_internal)."""
"""Prepare an internal representation of tool state (request_internal) for storing in the database."""

external_state.validate(input_models)

def decode_src_dict(src_dict: dict):
if "id" in src_dict:
decoded_dict = src_dict.copy()
decoded_dict["id"] = decode_id(src_dict["id"])
return decoded_dict
else:
return src_dict

def decode_callback(parameter: ToolParameterT, value: Any):
if parameter.parameter_type == "gx_data":
if value is None:
return VISITOR_NO_REPLACEMENT
data_parameter = cast(DataParameterModel, parameter)
if data_parameter.multiple:
assert isinstance(value, list), str(value)
return list(map(decode_src_dict, value))
else:
assert isinstance(value, dict), str(value)
return decode_src_dict(value)
elif parameter.parameter_type == "gx_data_collection":
if value is None:
return VISITOR_NO_REPLACEMENT
assert isinstance(value, dict), str(value)
return decode_src_dict(value)
else:
return VISITOR_NO_REPLACEMENT

decode_callback = _decode_callback_for(decode_id)
internal_state_dict = visit_input_values(
input_models,
external_state,
Expand All @@ -100,44 +85,53 @@ def decode_callback(parameter: ToolParameterT, value: Any):


def encode(
external_state: RequestInternalToolState, input_models: ToolParameterBundle, encode_id: Callable[[int], str]
internal_state: RequestInternalToolState, input_models: ToolParameterBundle, encode_id: EncodeFunctionT
) -> RequestToolState:
"""Prepare an external representation of tool state (request) for storing in the database (request_internal)."""

def encode_src_dict(src_dict: dict):
if "id" in src_dict:
encoded_dict = src_dict.copy()
encoded_dict["id"] = encode_id(src_dict["id"])
return encoded_dict
else:
return src_dict

def encode_callback(parameter: ToolParameterT, value: Any):
if parameter.parameter_type == "gx_data":
data_parameter = cast(DataParameterModel, parameter)
if data_parameter.multiple:
assert isinstance(value, list), str(value)
return list(map(encode_src_dict, value))
else:
assert isinstance(value, dict), str(value)
return encode_src_dict(value)
elif parameter.parameter_type == "gx_data_collection":
assert isinstance(value, dict), str(value)
return encode_src_dict(value)
else:
return VISITOR_NO_REPLACEMENT
"""Prepare an external representation of tool state (request) from persisted state in the database (request_internal)."""

encode_callback = _encode_callback_for(encode_id)
request_state_dict = visit_input_values(
input_models,
external_state,
internal_state,
encode_callback,
)
request_state = RequestToolState(request_state_dict)
request_state.validate(input_models)
return request_state


DereferenceCallable = Callable[[DataRequestUri], DataRequestInternalHda]
def landing_decode(
external_state: LandingRequestToolState, input_models: ToolParameterBundle, decode_id: Callable[[str], int]
) -> LandingRequestInternalToolState:
"""Prepare an external representation of tool state (request) for storing in the database (request_internal)."""

external_state.validate(input_models)
decode_callback = _decode_callback_for(decode_id)
internal_state_dict = visit_input_values(
input_models,
external_state,
decode_callback,
)

internal_request_state = LandingRequestInternalToolState(internal_state_dict)
internal_request_state.validate(input_models)
return internal_request_state


def landing_encode(
internal_state: LandingRequestInternalToolState, input_models: ToolParameterBundle, encode_id: EncodeFunctionT
) -> LandingRequestToolState:
"""Prepare an external representation of tool state (request) for storing in the database (request_internal)."""

encode_callback = _encode_callback_for(encode_id)
request_state_dict = visit_input_values(
input_models,
internal_state,
encode_callback,
)
request_state = LandingRequestToolState(request_state_dict)
request_state.validate(input_models)
return request_state


def dereference(
Expand Down Expand Up @@ -177,12 +171,6 @@ def dereference_callback(parameter: ToolParameterT, value: Any):
return request_state


# interfaces for adapting test data dictionaries to tool request dictionaries
# e.g. {class: File, path: foo.bed} => {src: hda, id: ab1235cdfea3}
AdaptDatasets = Callable[[JsonTestDatasetDefDict], DataRequestHda]
AdaptCollections = Callable[[JsonTestCollectionDefDict], DataCollectionRequest]


def encode_test(
test_case_state: TestCaseToolState,
input_models: ToolParameterBundle,
Expand Down Expand Up @@ -324,7 +312,6 @@ def _fill_default_for(tool_state: Dict[str, Any], parameter: ToolParameterT) ->
)
test_value = validate_explicit_conditional_test_value(test_parameter_name, explicit_test_value)
when = _select_which_when(conditional, test_value, conditional_state)
test_parameter = conditional.test_parameter
_fill_default_for(conditional_state, test_parameter)
_fill_defaults(conditional_state, when)
elif parameter_type in ["gx_repeat"]:
Expand Down Expand Up @@ -358,3 +345,63 @@ def _select_which_when(
raise Exception(
f"Invalid conditional test value ({test_value}) for parameter ({conditional.test_parameter.name})"
)


def _encode_callback_for(encode_id: EncodeFunctionT) -> Callback:

def encode_src_dict(src_dict: dict):
if "id" in src_dict:
encoded_dict = src_dict.copy()
encoded_dict["id"] = encode_id(src_dict["id"])
return encoded_dict
else:
return src_dict

def encode_callback(parameter: ToolParameterT, value: Any):
if parameter.parameter_type == "gx_data":
data_parameter = cast(DataParameterModel, parameter)
if data_parameter.multiple:
assert isinstance(value, list), str(value)
return list(map(encode_src_dict, value))
else:
assert isinstance(value, dict), str(value)
return encode_src_dict(value)
elif parameter.parameter_type == "gx_data_collection":
assert isinstance(value, dict), str(value)
return encode_src_dict(value)
else:
return VISITOR_NO_REPLACEMENT

return encode_callback


def _decode_callback_for(decode_id: DecodeFunctionT) -> Callback:

def decode_src_dict(src_dict: dict):
if "id" in src_dict:
decoded_dict = src_dict.copy()
decoded_dict["id"] = decode_id(src_dict["id"])
return decoded_dict
else:
return src_dict

def decode_callback(parameter: ToolParameterT, value: Any):
if parameter.parameter_type == "gx_data":
if value is None:
return VISITOR_NO_REPLACEMENT
data_parameter = cast(DataParameterModel, parameter)
if data_parameter.multiple:
assert isinstance(value, list), str(value)
return list(map(decode_src_dict, value))
else:
assert isinstance(value, dict), str(value)
return decode_src_dict(value)
elif parameter.parameter_type == "gx_data_collection":
if value is None:
return VISITOR_NO_REPLACEMENT
assert isinstance(value, dict), str(value)
return decode_src_dict(value)
else:
return VISITOR_NO_REPLACEMENT

return decode_callback
Loading
Loading