Add model_dir arg to testing functions #22

Merged
merged 8 commits on Sep 19, 2024
1 change: 1 addition & 0 deletions docs/deployment.rst
@@ -18,6 +18,7 @@ like this:
    COPY entrypoint.sh /usr/local/bin/
    RUN python -m pip install \
        gunicorn \
        inference-server \
        shipping-forecast  # Our package implementing the hooks

    EXPOSE 8080
15 changes: 14 additions & 1 deletion docs/testing.rst
@@ -39,6 +39,20 @@ Here we can use any serializer compatible with :mod:`sagemaker.serializers` and

If no serializer or deserializer is configured, bytes data are passed through as is for both input and output.

:func:`inference_server.testing.predict` accepts a ``model_dir`` argument which can be used to set the directory
containing the model artifacts to be loaded. At runtime, this directory is always :file:`/opt/ml/model`. In our tests,
we may want to create model artifacts on the fly, for example in a temporary directory using a pytest fixture, like
this::

    import pathlib

    import pytest

    @pytest.fixture
    def model_artifacts_dir(tmp_path) -> pathlib.Path:
        dir_ = tmp_path / "model"
        dir_.mkdir()
        # Instantiate a model object and serialize it as one or more files in the directory
        ...
        return dir_

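Such a fixture can then be passed straight to :func:`inference_server.testing.predict` through its ``model_dir``
argument. As a rough sketch (the test name, input bytes, and assertion are illustrative only; we assume the default
bytes pass-through, i.e. no serializer or deserializer configured)::

    def test_predict_with_model_artifacts(model_artifacts_dir):
        prediction = inference_server.testing.predict(b"some input data", model_dir=model_artifacts_dir)
        # Assert on the prediction produced by the model loaded from the fixture directory
        assert prediction is not None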

Testing model predictions (low-level API)
-----------------------------------------
@@ -63,7 +77,6 @@ Instead of using the high-level testing API, we can also use invoke requests sim
    assert response.json() == expected_prediction



Verifying plugin registration
-----------------------------

2 changes: 2 additions & 0 deletions pyproject.toml
@@ -67,6 +67,7 @@ dependencies = [
[project.optional-dependencies]

docs = [
"pytest", # Because we import this in inference_server.testing
"sphinx",
"sphinx-rtd-theme",
]
@@ -81,6 +82,7 @@ linting = [
    "isort",
    "mypy",
    "pre-commit",
    "pytest",  # Because we import this in inference_server.testing
]


3 changes: 2 additions & 1 deletion src/inference_server/_plugin.py
@@ -46,7 +46,8 @@ def model_fn(model_dir: str) -> ModelType:
    This function will be called when the server starts up. Here, ``ModelType`` can be any Python class corresponding to
    the model, for example :class:`sklearn.tree.DecisionTreeClassifier`.

    :param model_dir: Local filesystem directory containing the model files
    :param model_dir: Local filesystem directory containing the model files. This is always :file:`/opt/ml/model` when
        invoked by **inference-server**.
    """
    raise NotImplementedError

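For illustration, a concrete implementation of this hook might look like the following rough sketch (the plugin class
name and the ``model.pkl`` artifact file name are made up; only the ``model_fn`` signature and the
``inference_server.plugin_hook`` decorator come from this package)::

    import pathlib
    import pickle

    import inference_server


    class ShippingForecastPlugin:
        """Hypothetical plugin providing the model-loading hook"""

        @staticmethod
        @inference_server.plugin_hook()
        def model_fn(model_dir: str):
            # Load a previously pickled model from the directory passed in by inference-server
            with open(pathlib.Path(model_dir) / "model.pkl", "rb") as artifact:
                return pickle.load(artifact)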
24 changes: 19 additions & 5 deletions src/inference_server/testing.py
@@ -14,11 +14,13 @@
"""

import io
import pathlib
from types import ModuleType
from typing import Any, Callable, Optional, Protocol, Tuple, Type, Union

import botocore.response # type: ignore[import-untyped]
import pluggy
import pytest
import werkzeug.test

import inference_server
@@ -79,12 +81,18 @@ def deserialize(self, stream: "botocore.response.StreamingBody", content_type: s


def predict(
    data: Any, serializer: Optional[ImplementsSerialize] = None, deserializer: Optional[ImplementsDeserialize] = None
    data: Any,
    *,
    model_dir: Optional[pathlib.Path] = None,
    serializer: Optional[ImplementsSerialize] = None,
    deserializer: Optional[ImplementsDeserialize] = None,
) -> Any:
    """
    Invoke the model and return a prediction

    :param data: Model input data
    :param model_dir: Optional. A custom model directory to load the model from. Default is
        :file:`/opt/ml/model/`.
    :param serializer: Optional. A serializer for sending the data as bytes to the model server. Should be compatible
        with :class:`sagemaker.serializers.BaseSerializer`. Default: bytes pass-through.
    :param deserializer: Optional. A deserializer for processing the prediction as sent by the model server. Should be
@@ -98,7 +106,7 @@ def predict(
"Content-Type": serializer.CONTENT_TYPE, # The serializer declares the content-type of the input data
"Accept": ", ".join(deserializer.ACCEPT), # The deserializer dictates the content-type of the prediction
}
prediction_response = post_invocations(data=serialized_data, headers=http_headers)
prediction_response = post_invocations(model_dir=model_dir, data=serialized_data, headers=http_headers)
prediction_stream = botocore.response.StreamingBody(
raw_stream=io.BytesIO(prediction_response.data),
content_length=prediction_response.content_length,
@@ -117,15 +125,21 @@ def client() -> werkzeug.test.Client:
    return werkzeug.test.Client(inference_server.create_app())


def post_invocations(**kwargs) -> werkzeug.test.TestResponse:
def post_invocations(*, model_dir: Optional[pathlib.Path] = None, **kwargs) -> werkzeug.test.TestResponse:
    """
    Send an HTTP POST request to ``/invocations`` using a test HTTP client and return the response

    This function should be used to verify an inference request using the full **inference-server** logic.

    :param kwargs: Keyword arguments passed to :meth:`werkzeug.test.Client.post`
    :param model_dir: Optional. A custom model directory to load the model from. Default is :file:`/opt/ml/model/`.
    :param kwargs: Keyword arguments passed to :meth:`werkzeug.test.Client.post`
    """
    response = client().post("/invocations", **kwargs)
    # pytest should be available when we are using inference_server.testing
    with pytest.MonkeyPatch.context() as monkeypatch:
        if model_dir:
            monkeypatch.setattr(inference_server, "_MODEL_DIR", str(model_dir))
        response = client().post("/invocations", **kwargs)

    assert response.status_code == 200
    return response

49 changes: 48 additions & 1 deletion tests/test_inference_server.py
@@ -8,7 +8,7 @@
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import pathlib
from typing import Tuple

import botocore.response
@@ -22,6 +22,14 @@ def test_package_has_version():
    assert inference_server.__version__ is not None


@pytest.fixture(autouse=True)
def reset_caches():
    try:
        yield
    finally:
        # Clear the cached model after each test so later tests load it afresh (e.g. from a different model_dir)
        inference_server._model.cache_clear()


@pytest.fixture
def client():
    return inference_server.testing.client()
@@ -46,6 +54,26 @@ def ping_fn(model):
        pm.unregister(PingPlugin)


@pytest.fixture
def model_using_dir():
    class ModelPlugin:
        """Plugin which just defines a model_fn"""

        @staticmethod
        @inference_server.plugin_hook()
        def model_fn(model_dir: str):
            """Model function for testing that we are passing a custom directory"""
            assert model_dir != "/opt/ml/model"
            return lambda data: data

    pm = inference_server.testing.plugin_manager()
    pm.register(ModelPlugin)
    try:
        yield
    finally:
        pm.unregister(ModelPlugin)


def test_version():
"""Test that the package has a version"""
assert inference_server.__version__ is not None
@@ -80,6 +108,17 @@ def test_invocations():
    assert response.headers["Content-Type"] == "application/octet-stream"


def test_invocations_custom_model_dir(model_using_dir):
    """Test passing a custom model_dir using low-level testing.post_invocations"""
    data = b"What's the shipping forecast for tomorrow"
    model_dir = pathlib.Path(__file__).parent

    response = inference_server.testing.post_invocations(
        data=data, model_dir=model_dir, headers={"Accept": "application/octet-stream"}
    )
    assert response.data == data


def test_prediction_custom_serializer():
"""Test the default plugin again, now using high-level testing.predict"""

@@ -115,6 +154,14 @@ def test_prediction_no_serializer():
    assert prediction == input_data


def test_prediction_model_dir(model_using_dir):
    """Test passing a custom model_dir using high-level testing.predict"""
    input_data = b"What's the shipping forecast for tomorrow"
    model_dir = pathlib.Path(__file__).parent

    prediction = inference_server.testing.predict(input_data, model_dir=model_dir)
    assert prediction == input_data


def test_execution_parameters(client):
    response = client.get("/execution-parameters")
    assert response.data == b'{"BatchStrategy":"MultiRecord","MaxConcurrentTransforms":1,"MaxPayloadInMB":6}'