diff --git a/CHANGELOG.md b/CHANGELOG.md index e6fd124f..b96a2a5b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ### Added - Added `datetime_frequency_interval` parameter for `datetime_frequency` aggregation. [#294](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/294) +- Added rate limiting functionality with configurable limits using environment variable `STAC_FASTAPI_RATE_LIMIT`, example: `500/minute`. [#303](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/303) ### Changed diff --git a/README.md b/README.md index d4b8754c..6ec94b72 100644 --- a/README.md +++ b/README.md @@ -383,4 +383,8 @@ Available aggregations are: - geometry_geohash_grid_frequency ([geohash grid](https://opensearch.org/docs/latest/aggregations/bucket/geohash-grid/) on Item.geometry) - geometry_geotile_grid_frequency ([geotile grid](https://opensearch.org/docs/latest/aggregations/bucket/geotile-grid/) on Item.geometry) -Support for additional fields and new aggregations can be added in the associated `database_logic.py` file. \ No newline at end of file +Support for additional fields and new aggregations can be added in the associated `database_logic.py` file. + +## Rate Limiting + +Rate limiting is an optional security feature that controls API request frequency on a remote address basis. It's enabled by setting the `STAC_FASTAPI_RATE_LIMIT` environment variable, e.g., `500/minute`. This limits each client to 500 requests per minute, helping prevent abuse and maintain API stability. Implementation examples are available in the [examples/rate_limit](examples/rate_limit) directory. \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml index 23455e2e..da4633b9 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -54,6 +54,7 @@ services: - ES_USE_SSL=false - ES_VERIFY_CERTS=false - BACKEND=opensearch + - STAC_FASTAPI_RATE_LIMIT=200/minute ports: - "8082:8082" volumes: diff --git a/examples/rate_limit/docker-compose.rate_limit.yml b/examples/rate_limit/docker-compose.rate_limit.yml new file mode 100644 index 00000000..5416e139 --- /dev/null +++ b/examples/rate_limit/docker-compose.rate_limit.yml @@ -0,0 +1,94 @@ +version: '3.9' + +services: + app-elasticsearch: + container_name: stac-fastapi-es + image: stac-utils/stac-fastapi-es + restart: always + build: + context: . + dockerfile: dockerfiles/Dockerfile.dev.es + environment: + - STAC_FASTAPI_TITLE=stac-fastapi-elasticsearch + - STAC_FASTAPI_DESCRIPTION=A STAC FastAPI with an Elasticsearch backend + - STAC_FASTAPI_VERSION=2.1 + - APP_HOST=0.0.0.0 + - APP_PORT=8080 + - RELOAD=true + - ENVIRONMENT=local + - WEB_CONCURRENCY=10 + - ES_HOST=elasticsearch + - ES_PORT=9200 + - ES_USE_SSL=false + - ES_VERIFY_CERTS=false + - BACKEND=elasticsearch + - STAC_FASTAPI_RATE_LIMIT=500/minute + ports: + - "8080:8080" + volumes: + - ./stac_fastapi:/app/stac_fastapi + - ./scripts:/app/scripts + - ./esdata:/usr/share/elasticsearch/data + depends_on: + - elasticsearch + command: + bash -c "./scripts/wait-for-it-es.sh es-container:9200 && python -m stac_fastapi.elasticsearch.app" + + app-opensearch: + container_name: stac-fastapi-os + image: stac-utils/stac-fastapi-os + restart: always + build: + context: . + dockerfile: dockerfiles/Dockerfile.dev.os + environment: + - STAC_FASTAPI_TITLE=stac-fastapi-opensearch + - STAC_FASTAPI_DESCRIPTION=A STAC FastAPI with an Opensearch backend + - STAC_FASTAPI_VERSION=3.0.0a2 + - APP_HOST=0.0.0.0 + - APP_PORT=8082 + - RELOAD=true + - ENVIRONMENT=local + - WEB_CONCURRENCY=10 + - ES_HOST=opensearch + - ES_PORT=9202 + - ES_USE_SSL=false + - ES_VERIFY_CERTS=false + - BACKEND=opensearch + - STAC_FASTAPI_RATE_LIMIT=200/minute + ports: + - "8082:8082" + volumes: + - ./stac_fastapi:/app/stac_fastapi + - ./scripts:/app/scripts + - ./osdata:/usr/share/opensearch/data + depends_on: + - opensearch + command: + bash -c "./scripts/wait-for-it-es.sh os-container:9202 && python -m stac_fastapi.opensearch.app" + + elasticsearch: + container_name: es-container + image: docker.elastic.co/elasticsearch/elasticsearch:${ELASTICSEARCH_VERSION:-8.11.0} + hostname: elasticsearch + environment: + ES_JAVA_OPTS: -Xms512m -Xmx1g + volumes: + - ./elasticsearch/config/elasticsearch.yml:/usr/share/elasticsearch/config/elasticsearch.yml + - ./elasticsearch/snapshots:/usr/share/elasticsearch/snapshots + ports: + - "9200:9200" + + opensearch: + container_name: os-container + image: opensearchproject/opensearch:${OPENSEARCH_VERSION:-2.11.1} + hostname: opensearch + environment: + - discovery.type=single-node + - plugins.security.disabled=true + - OPENSEARCH_JAVA_OPTS=-Xms512m -Xmx512m + volumes: + - ./opensearch/config/opensearch.yml:/usr/share/opensearch/config/opensearch.yml + - ./opensearch/snapshots:/usr/share/opensearch/snapshots + ports: + - "9202:9202" diff --git a/stac_fastapi/core/setup.py b/stac_fastapi/core/setup.py index ac16bfc8..9a359afc 100644 --- a/stac_fastapi/core/setup.py +++ b/stac_fastapi/core/setup.py @@ -19,6 +19,7 @@ "pygeofilter==0.2.1", "typing_extensions==4.8.0", "jsonschema", + "slowapi==0.1.9", ] setup( diff --git a/stac_fastapi/core/stac_fastapi/core/rate_limit.py b/stac_fastapi/core/stac_fastapi/core/rate_limit.py new file mode 100644 index 00000000..3c90f73f --- /dev/null +++ b/stac_fastapi/core/stac_fastapi/core/rate_limit.py @@ -0,0 +1,44 @@ +"""Rate limiting middleware.""" + +import logging +import os +from typing import Optional + +from fastapi import FastAPI, Request +from slowapi import Limiter, _rate_limit_exceeded_handler +from slowapi.errors import RateLimitExceeded +from slowapi.middleware import SlowAPIMiddleware +from slowapi.util import get_remote_address + +logger = logging.getLogger(__name__) + + +def get_limiter(key_func=get_remote_address): + """Create and return a Limiter instance for rate limiting.""" + return Limiter(key_func=key_func) + + +def setup_rate_limit( + app: FastAPI, rate_limit: Optional[str] = None, key_func=get_remote_address +): + """Set up rate limiting middleware.""" + RATE_LIMIT = rate_limit or os.getenv("STAC_FASTAPI_RATE_LIMIT") + + if not RATE_LIMIT: + logger.info("Rate limiting is disabled") + return + + logger.info(f"Setting up rate limit with RATE_LIMIT={RATE_LIMIT}") + + limiter = get_limiter(key_func) + app.state.limiter = limiter + app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) + app.add_middleware(SlowAPIMiddleware) + + @app.middleware("http") + @limiter.limit(RATE_LIMIT) + async def rate_limit_middleware(request: Request, call_next): + response = await call_next(request) + return response + + logger.info("Rate limit setup complete") diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py index 6b26c2ac..5e6307e7 100644 --- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py +++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py @@ -17,6 +17,7 @@ EsAsyncAggregationClient, ) from stac_fastapi.core.extensions.fields import FieldsExtension +from stac_fastapi.core.rate_limit import setup_rate_limit from stac_fastapi.core.route_dependencies import get_route_dependencies from stac_fastapi.core.session import Session from stac_fastapi.elasticsearch.config import ElasticsearchSettings @@ -97,6 +98,9 @@ app = api.app app.root_path = os.getenv("STAC_FASTAPI_ROOT_PATH", "") +# Add rate limit +setup_rate_limit(app, rate_limit=os.getenv("STAC_FASTAPI_RATE_LIMIT")) + @app.on_event("startup") async def _startup_event() -> None: diff --git a/stac_fastapi/opensearch/stac_fastapi/opensearch/app.py b/stac_fastapi/opensearch/stac_fastapi/opensearch/app.py index 2a764518..8be0eafd 100644 --- a/stac_fastapi/opensearch/stac_fastapi/opensearch/app.py +++ b/stac_fastapi/opensearch/stac_fastapi/opensearch/app.py @@ -17,6 +17,7 @@ EsAsyncAggregationClient, ) from stac_fastapi.core.extensions.fields import FieldsExtension +from stac_fastapi.core.rate_limit import setup_rate_limit from stac_fastapi.core.route_dependencies import get_route_dependencies from stac_fastapi.core.session import Session from stac_fastapi.extensions.core import ( @@ -97,6 +98,9 @@ app = api.app app.root_path = os.getenv("STAC_FASTAPI_ROOT_PATH", "") +# Add rate limit +setup_rate_limit(app, rate_limit=os.getenv("STAC_FASTAPI_RATE_LIMIT")) + @app.on_event("startup") async def _startup_event() -> None: diff --git a/stac_fastapi/tests/basic_auth/test_basic_auth.py b/stac_fastapi/tests/basic_auth/test_basic_auth.py index 95be59ee..11167fd9 100644 --- a/stac_fastapi/tests/basic_auth/test_basic_auth.py +++ b/stac_fastapi/tests/basic_auth/test_basic_auth.py @@ -18,7 +18,7 @@ async def test_get_search_not_authenticated(app_client_basic_auth, ctx): @pytest.mark.asyncio async def test_post_search_authenticated(app_client_basic_auth, ctx): - """Test protected endpoint [POST /search] with reader auhtentication""" + """Test protected endpoint [POST /search] with reader authentication""" if not os.getenv("BASIC_AUTH"): pytest.skip() params = {"id": ctx.item["id"]} @@ -34,7 +34,7 @@ async def test_post_search_authenticated(app_client_basic_auth, ctx): async def test_delete_resource_anonymous( app_client_basic_auth, ): - """Test protected endpoint [DELETE /collections/{collection_id}] without auhtentication""" + """Test protected endpoint [DELETE /collections/{collection_id}] without authentication""" if not os.getenv("BASIC_AUTH"): pytest.skip() diff --git a/stac_fastapi/tests/conftest.py b/stac_fastapi/tests/conftest.py index ca2d8436..651cdadb 100644 --- a/stac_fastapi/tests/conftest.py +++ b/stac_fastapi/tests/conftest.py @@ -24,6 +24,7 @@ EsAggregationExtensionPostRequest, EsAsyncAggregationClient, ) +from stac_fastapi.core.rate_limit import setup_rate_limit from stac_fastapi.core.route_dependencies import get_route_dependencies if os.getenv("BACKEND", "elasticsearch").lower() == "opensearch": @@ -246,6 +247,65 @@ async def app_client(app): yield c +@pytest_asyncio.fixture(scope="session") +async def app_rate_limit(): + settings = AsyncSettings() + + aggregation_extension = AggregationExtension( + client=EsAsyncAggregationClient( + database=database, session=None, settings=settings + ) + ) + aggregation_extension.POST = EsAggregationExtensionPostRequest + aggregation_extension.GET = EsAggregationExtensionGetRequest + + search_extensions = [ + TransactionExtension( + client=TransactionsClient( + database=database, session=None, settings=settings + ), + settings=settings, + ), + SortExtension(), + FieldsExtension(), + QueryExtension(), + TokenPaginationExtension(), + FilterExtension(), + FreeTextExtension(), + ] + + extensions = [aggregation_extension] + search_extensions + + post_request_model = create_post_request_model(search_extensions) + + app = StacApi( + settings=settings, + client=CoreClient( + database=database, + session=None, + extensions=extensions, + post_request_model=post_request_model, + ), + extensions=extensions, + search_get_request_model=create_get_request_model(search_extensions), + search_post_request_model=post_request_model, + ).app + + # Set up rate limit + setup_rate_limit(app, rate_limit="2/minute") + + return app + + +@pytest_asyncio.fixture(scope="session") +async def app_client_rate_limit(app_rate_limit): + await create_index_templates() + await create_collection_index() + + async with AsyncClient(app=app_rate_limit, base_url="http://test-server") as c: + yield c + + @pytest_asyncio.fixture(scope="session") async def app_basic_auth(): diff --git a/stac_fastapi/tests/rate_limit/test_rate_limit.py b/stac_fastapi/tests/rate_limit/test_rate_limit.py new file mode 100644 index 00000000..fd6b5bce --- /dev/null +++ b/stac_fastapi/tests/rate_limit/test_rate_limit.py @@ -0,0 +1,38 @@ +import logging + +import pytest +from httpx import AsyncClient +from slowapi.errors import RateLimitExceeded + +logger = logging.getLogger(__name__) + + +@pytest.mark.asyncio +async def test_rate_limit(app_client_rate_limit: AsyncClient, ctx): + expected_status_codes = [200, 200, 429, 429, 429] + + for i, expected_status_code in enumerate(expected_status_codes): + try: + response = await app_client_rate_limit.get("/collections") + status_code = response.status_code + except RateLimitExceeded: + status_code = 429 + + logger.info(f"Request {i+1}: Status code {status_code}") + assert ( + status_code == expected_status_code + ), f"Expected status code {expected_status_code}, but got {status_code}" + + +@pytest.mark.asyncio +async def test_rate_limit_no_limit(app_client: AsyncClient, ctx): + expected_status_codes = [200, 200, 200, 200, 200] + + for i, expected_status_code in enumerate(expected_status_codes): + response = await app_client.get("/collections") + status_code = response.status_code + + logger.info(f"Request {i+1}: Status code {status_code}") + assert ( + status_code == expected_status_code + ), f"Expected status code {expected_status_code}, but got {status_code}"