Skip to content

Commit

Permalink
Rate limit implementation (#303)
Browse files Browse the repository at this point in the history
**Description:**
Implements **optional**, address-based global rate limiting.

Rate limiting is an optional security feature that controls API request
frequency on a remote address basis. It's enabled by setting the
`STAC_FASTAPI_RATE_LIMIT` environment variable, e.g., `500/minute`. This
limits each client to 500 requests per minute, helping prevent abuse and
maintain API stability. Implementation examples are available in the
[examples/rate_limit](examples/rate_limit) directory.

**PR Checklist:**

- [x] Code is formatted and linted (run `pre-commit run --all-files`)
- [x] Tests pass (run `make test`)
- [x] Documentation has been updated to reflect changes, if applicable
- [x] Changes are added to the changelog
  • Loading branch information
pedro-cf authored Oct 6, 2024
1 parent 2d6cb4d commit bbbba05
Show file tree
Hide file tree
Showing 11 changed files with 254 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
### Added

- Added `datetime_frequency_interval` parameter for `datetime_frequency` aggregation. [#294](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/294)
- Added rate limiting functionality with configurable limits using environment variable `STAC_FASTAPI_RATE_LIMIT`, example: `500/minute`. [#303](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/303)

### Changed

Expand Down
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -383,4 +383,8 @@ Available aggregations are:
- geometry_geohash_grid_frequency ([geohash grid](https://opensearch.org/docs/latest/aggregations/bucket/geohash-grid/) on Item.geometry)
- geometry_geotile_grid_frequency ([geotile grid](https://opensearch.org/docs/latest/aggregations/bucket/geotile-grid/) on Item.geometry)

Support for additional fields and new aggregations can be added in the associated `database_logic.py` file.
Support for additional fields and new aggregations can be added in the associated `database_logic.py` file.

## Rate Limiting

Rate limiting is an optional security feature that limits how many API requests each remote address may make in a given time window. It is enabled by setting the `STAC_FASTAPI_RATE_LIMIT` environment variable, e.g., `500/minute`, which limits each client to 500 requests per minute and helps prevent abuse and maintain API stability. When the variable is unset, rate limiting is disabled. Implementation examples are available in the [examples/rate_limit](examples/rate_limit) directory.
1 change: 1 addition & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ services:
- ES_USE_SSL=false
- ES_VERIFY_CERTS=false
- BACKEND=opensearch
- STAC_FASTAPI_RATE_LIMIT=200/minute
ports:
- "8082:8082"
volumes:
Expand Down
94 changes: 94 additions & 0 deletions examples/rate_limit/docker-compose.rate_limit.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
---
# Example Compose stack showing the STAC_FASTAPI_RATE_LIMIT setting.
# Two API containers (Elasticsearch- and OpenSearch-backed) are started next
# to their search backends; each API gets its own per-client request limit.
#
# NOTE(review): the build context and bind-mount paths below are relative and
# appear to assume the repository root as the Compose project directory --
# confirm how this file is meant to be invoked.
version: "3.9"

services:
  # STAC API backed by Elasticsearch, limited to 500 requests per minute.
  app-elasticsearch:
    container_name: stac-fastapi-es
    image: stac-utils/stac-fastapi-es
    restart: always
    build:
      context: .
      dockerfile: dockerfiles/Dockerfile.dev.es
    environment:
      - STAC_FASTAPI_TITLE=stac-fastapi-elasticsearch
      - STAC_FASTAPI_DESCRIPTION=A STAC FastAPI with an Elasticsearch backend
      - STAC_FASTAPI_VERSION=2.1
      - APP_HOST=0.0.0.0
      - APP_PORT=8080
      - RELOAD=true
      - ENVIRONMENT=local
      - WEB_CONCURRENCY=10
      - ES_HOST=elasticsearch
      - ES_PORT=9200
      - ES_USE_SSL=false
      - ES_VERIFY_CERTS=false
      - BACKEND=elasticsearch
      # Rate limit applied by the API; unset to disable rate limiting.
      - STAC_FASTAPI_RATE_LIMIT=500/minute
    ports:
      - "8080:8080"
    volumes:
      - ./stac_fastapi:/app/stac_fastapi
      - ./scripts:/app/scripts
      - ./esdata:/usr/share/elasticsearch/data
    depends_on:
      - elasticsearch
    # Wait for the search backend before launching the API module.
    command: >-
      bash -c "./scripts/wait-for-it-es.sh es-container:9200
      && python -m stac_fastapi.elasticsearch.app"

  # STAC API backed by OpenSearch, limited to 200 requests per minute.
  app-opensearch:
    container_name: stac-fastapi-os
    image: stac-utils/stac-fastapi-os
    restart: always
    build:
      context: .
      dockerfile: dockerfiles/Dockerfile.dev.os
    environment:
      - STAC_FASTAPI_TITLE=stac-fastapi-opensearch
      - STAC_FASTAPI_DESCRIPTION=A STAC FastAPI with an Opensearch backend
      - STAC_FASTAPI_VERSION=3.0.0a2
      - APP_HOST=0.0.0.0
      - APP_PORT=8082
      - RELOAD=true
      - ENVIRONMENT=local
      - WEB_CONCURRENCY=10
      - ES_HOST=opensearch
      - ES_PORT=9202
      - ES_USE_SSL=false
      - ES_VERIFY_CERTS=false
      - BACKEND=opensearch
      # Rate limit applied by the API; unset to disable rate limiting.
      - STAC_FASTAPI_RATE_LIMIT=200/minute
    ports:
      - "8082:8082"
    volumes:
      - ./stac_fastapi:/app/stac_fastapi
      - ./scripts:/app/scripts
      - ./osdata:/usr/share/opensearch/data
    depends_on:
      - opensearch
    # Wait for the search backend before launching the API module.
    command: >-
      bash -c "./scripts/wait-for-it-es.sh os-container:9202
      && python -m stac_fastapi.opensearch.app"

  # Elasticsearch backend; the image tag can be overridden through
  # ELASTICSEARCH_VERSION.
  elasticsearch:
    container_name: es-container
    image: docker.elastic.co/elasticsearch/elasticsearch:${ELASTICSEARCH_VERSION:-8.11.0}
    hostname: elasticsearch
    environment:
      - ES_JAVA_OPTS=-Xms512m -Xmx1g
    volumes:
      - ./elasticsearch/config/elasticsearch.yml:/usr/share/elasticsearch/config/elasticsearch.yml
      - ./elasticsearch/snapshots:/usr/share/elasticsearch/snapshots
    ports:
      - "9200:9200"

  # OpenSearch backend; the image tag can be overridden through
  # OPENSEARCH_VERSION.
  opensearch:
    container_name: os-container
    image: opensearchproject/opensearch:${OPENSEARCH_VERSION:-2.11.1}
    hostname: opensearch
    environment:
      - discovery.type=single-node
      - plugins.security.disabled=true
      - OPENSEARCH_JAVA_OPTS=-Xms512m -Xmx512m
    volumes:
      - ./opensearch/config/opensearch.yml:/usr/share/opensearch/config/opensearch.yml
      - ./opensearch/snapshots:/usr/share/opensearch/snapshots
    ports:
      - "9202:9202"
1 change: 1 addition & 0 deletions stac_fastapi/core/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"pygeofilter==0.2.1",
"typing_extensions==4.8.0",
"jsonschema",
"slowapi==0.1.9",
]

setup(
Expand Down
44 changes: 44 additions & 0 deletions stac_fastapi/core/stac_fastapi/core/rate_limit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""Rate limiting middleware."""

import logging
import os
from typing import Optional

from fastapi import FastAPI, Request
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
from slowapi.util import get_remote_address

logger = logging.getLogger(__name__)


def get_limiter(key_func=get_remote_address):
    """Build the slowapi ``Limiter`` used for rate limiting.

    Args:
        key_func: Callable passed to ``Limiter`` as its ``key_func``;
            defaults to slowapi's ``get_remote_address``.

    Returns:
        A new ``Limiter`` instance configured with ``key_func``.
    """
    limiter = Limiter(key_func=key_func)
    return limiter


def setup_rate_limit(
    app: FastAPI, rate_limit: Optional[str] = None, key_func=get_remote_address
):
    """Set up rate limiting middleware.

    Args:
        app: Application on which the limiter, exception handler and
            middleware are installed.
        rate_limit: Limit string such as ``"500/minute"``. When empty or
            ``None`` the ``STAC_FASTAPI_RATE_LIMIT`` environment variable is
            used instead.
        key_func: Callable handed to the limiter to identify a client;
            defaults to slowapi's ``get_remote_address``.

    Returns:
        None. If no limit is configured, ``app`` is left untouched.
    """
    RATE_LIMIT = rate_limit or os.getenv("STAC_FASTAPI_RATE_LIMIT")

    if not RATE_LIMIT:
        logger.info("Rate limiting is disabled")
        return

    logger.info(f"Setting up rate limit with RATE_LIMIT={RATE_LIMIT}")

    limiter = get_limiter(key_func)
    app.state.limiter = limiter
    app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
    app.add_middleware(SlowAPIMiddleware)

    # The limit check is attached to a pass-through function so it runs once
    # for every request that reaches the HTTP middleware below.
    @limiter.limit(RATE_LIMIT)
    async def rate_limit_middleware(request: Request, call_next):
        return await call_next(request)

    @app.middleware("http")
    async def rate_limit_guard(request: Request, call_next):
        # An exception raised inside an HTTP middleware is not routed to the
        # handlers registered with ``add_exception_handler``; left uncaught,
        # ``RateLimitExceeded`` would escape the application instead of
        # producing a response. Build the rate-limit response here instead.
        try:
            return await rate_limit_middleware(request, call_next)
        except RateLimitExceeded as exc:
            return _rate_limit_exceeded_handler(request, exc)

    logger.info("Rate limit setup complete")
4 changes: 4 additions & 0 deletions stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
EsAsyncAggregationClient,
)
from stac_fastapi.core.extensions.fields import FieldsExtension
from stac_fastapi.core.rate_limit import setup_rate_limit
from stac_fastapi.core.route_dependencies import get_route_dependencies
from stac_fastapi.core.session import Session
from stac_fastapi.elasticsearch.config import ElasticsearchSettings
Expand Down Expand Up @@ -97,6 +98,9 @@
app = api.app
app.root_path = os.getenv("STAC_FASTAPI_ROOT_PATH", "")

# Add rate limit
setup_rate_limit(app, rate_limit=os.getenv("STAC_FASTAPI_RATE_LIMIT"))


@app.on_event("startup")
async def _startup_event() -> None:
Expand Down
4 changes: 4 additions & 0 deletions stac_fastapi/opensearch/stac_fastapi/opensearch/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
EsAsyncAggregationClient,
)
from stac_fastapi.core.extensions.fields import FieldsExtension
from stac_fastapi.core.rate_limit import setup_rate_limit
from stac_fastapi.core.route_dependencies import get_route_dependencies
from stac_fastapi.core.session import Session
from stac_fastapi.extensions.core import (
Expand Down Expand Up @@ -97,6 +98,9 @@
app = api.app
app.root_path = os.getenv("STAC_FASTAPI_ROOT_PATH", "")

# Add rate limit
setup_rate_limit(app, rate_limit=os.getenv("STAC_FASTAPI_RATE_LIMIT"))


@app.on_event("startup")
async def _startup_event() -> None:
Expand Down
4 changes: 2 additions & 2 deletions stac_fastapi/tests/basic_auth/test_basic_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ async def test_get_search_not_authenticated(app_client_basic_auth, ctx):

@pytest.mark.asyncio
async def test_post_search_authenticated(app_client_basic_auth, ctx):
"""Test protected endpoint [POST /search] with reader auhtentication"""
"""Test protected endpoint [POST /search] with reader authentication"""
if not os.getenv("BASIC_AUTH"):
pytest.skip()
params = {"id": ctx.item["id"]}
Expand All @@ -34,7 +34,7 @@ async def test_post_search_authenticated(app_client_basic_auth, ctx):
async def test_delete_resource_anonymous(
app_client_basic_auth,
):
"""Test protected endpoint [DELETE /collections/{collection_id}] without auhtentication"""
"""Test protected endpoint [DELETE /collections/{collection_id}] without authentication"""
if not os.getenv("BASIC_AUTH"):
pytest.skip()

Expand Down
60 changes: 60 additions & 0 deletions stac_fastapi/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
EsAggregationExtensionPostRequest,
EsAsyncAggregationClient,
)
from stac_fastapi.core.rate_limit import setup_rate_limit
from stac_fastapi.core.route_dependencies import get_route_dependencies

if os.getenv("BACKEND", "elasticsearch").lower() == "opensearch":
Expand Down Expand Up @@ -246,6 +247,65 @@ async def app_client(app):
yield c


@pytest_asyncio.fixture(scope="session")
async def app_rate_limit():
    """Session-scoped STAC API application with a ``2/minute`` rate limit."""
    settings = AsyncSettings()

    # Aggregation extension with its GET/POST request models attached.
    aggregation_client = EsAsyncAggregationClient(
        database=database, session=None, settings=settings
    )
    aggregation_extension = AggregationExtension(client=aggregation_client)
    aggregation_extension.POST = EsAggregationExtensionPostRequest
    aggregation_extension.GET = EsAggregationExtensionGetRequest

    # Extensions that also shape the search request models.
    transactions_client = TransactionsClient(
        database=database, session=None, settings=settings
    )
    search_extensions = [
        TransactionExtension(client=transactions_client, settings=settings),
        SortExtension(),
        FieldsExtension(),
        QueryExtension(),
        TokenPaginationExtension(),
        FilterExtension(),
        FreeTextExtension(),
    ]

    extensions = [aggregation_extension] + search_extensions
    post_request_model = create_post_request_model(search_extensions)

    core_client = CoreClient(
        database=database,
        session=None,
        extensions=extensions,
        post_request_model=post_request_model,
    )
    api = StacApi(
        settings=settings,
        client=core_client,
        extensions=extensions,
        search_get_request_model=create_get_request_model(search_extensions),
        search_post_request_model=post_request_model,
    )
    app = api.app

    # Set up rate limit
    setup_rate_limit(app, rate_limit="2/minute")

    return app


@pytest_asyncio.fixture(scope="session")
async def app_client_rate_limit(app_rate_limit):
    """Yield an ``AsyncClient`` bound to the rate-limited application.

    The index templates and the collection index are created before the
    client is opened.
    """
    await create_index_templates()
    await create_collection_index()

    async with AsyncClient(
        app=app_rate_limit, base_url="http://test-server"
    ) as http_client:
        yield http_client


@pytest_asyncio.fixture(scope="session")
async def app_basic_auth():

Expand Down
38 changes: 38 additions & 0 deletions stac_fastapi/tests/rate_limit/test_rate_limit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import logging

import pytest
from httpx import AsyncClient
from slowapi.errors import RateLimitExceeded

logger = logging.getLogger(__name__)


@pytest.mark.asyncio
async def test_rate_limit(app_client_rate_limit: AsyncClient, ctx):
    """Requests beyond the configured limit must be reported as 429.

    Five consecutive ``GET /collections`` calls are issued; the first two are
    expected to succeed and the remaining three to be rejected. A
    ``RateLimitExceeded`` raised through the client is counted as a 429.
    """
    expected_status_codes = [200, 200, 429, 429, 429]

    for i, expected_status_code in enumerate(expected_status_codes):
        try:
            response = await app_client_rate_limit.get("/collections")
        except RateLimitExceeded:
            status_code = 429
        else:
            status_code = response.status_code

        logger.info(f"Request {i+1}: Status code {status_code}")
        assert (
            status_code == expected_status_code
        ), f"Expected status code {expected_status_code}, but got {status_code}"


@pytest.mark.asyncio
async def test_rate_limit_no_limit(app_client: AsyncClient, ctx):
    """Five consecutive ``GET /collections`` calls via ``app_client`` all return 200."""
    expected_status_codes = [200] * 5

    for i, expected_status_code in enumerate(expected_status_codes):
        status_code = (await app_client.get("/collections")).status_code

        logger.info(f"Request {i+1}: Status code {status_code}")
        assert (
            status_code == expected_status_code
        ), f"Expected status code {expected_status_code}, but got {status_code}"

0 comments on commit bbbba05

Please sign in to comment.