diff --git a/src/pypgstac/python/pypgstac/load.py b/src/pypgstac/python/pypgstac/load.py index 2e9bcc6..b7845a4 100644 --- a/src/pypgstac/python/pypgstac/load.py +++ b/src/pypgstac/python/pypgstac/load.py @@ -270,6 +270,7 @@ def load_partition( partition: Partition, items: Iterable[Dict[str, Any]], insert_mode: Optional[Methods] = Methods.insert, + partition_update_enabled: Optional[bool] = True, ) -> None: """Load items data for a single partition.""" conn = self.db.connect() @@ -441,15 +442,20 @@ def load_partition( "Available modes are insert, ignore, upsert, and delsert." f"You entered {insert_mode}.", ) - logger.debug("Updating Partition Stats") - cur.execute("SELECT update_partition_stats_q(%s);",(partition.name,)) - logger.debug(cur.statusmessage) - logger.debug(f"Rows affected: {cur.rowcount}") + if partition_update_enabled: + logger.debug("Updating Partition Stats") + cur.execute("SELECT update_partition_stats_q(%s);",(partition.name,)) + logger.debug(cur.statusmessage) + logger.debug(f"Rows affected: {cur.rowcount}") logger.debug( f"Copying data for {partition} took {time.perf_counter() - t} seconds", ) - def _partition_update(self, item: Dict[str, Any]) -> str: + def _partition_update( + self, + item: Dict[str, Any], + update_enabled: Optional[bool] = True, + ) -> str: """Update the cached partition with the item information and return the name. 
This method will mark the partition as dirty if the bounds of the partition @@ -515,20 +521,24 @@ def _partition_update(self, item: Dict[str, Any]) -> str: partition = self._partition_cache[partition_name] if partition: - # Only update the partition if the item is outside the current bounds - if item["datetime"] < partition.datetime_range_min: - partition.datetime_range_min = item["datetime"] - partition.requires_update = True - if item["datetime"] > partition.datetime_range_max: - partition.datetime_range_max = item["datetime"] - partition.requires_update = True - if item["end_datetime"] < partition.end_datetime_range_min: - partition.end_datetime_range_min = item["end_datetime"] - partition.requires_update = True - if item["end_datetime"] > partition.end_datetime_range_max: - partition.end_datetime_range_max = item["end_datetime"] - partition.requires_update = True + if update_enabled: + # Only update the partition if the item is outside the current bounds + if item["datetime"] < partition.datetime_range_min: + partition.datetime_range_min = item["datetime"] + partition.requires_update = True + if item["datetime"] > partition.datetime_range_max: + partition.datetime_range_max = item["datetime"] + partition.requires_update = True + if item["end_datetime"] < partition.end_datetime_range_min: + partition.end_datetime_range_min = item["end_datetime"] + partition.requires_update = True + if item["end_datetime"] > partition.end_datetime_range_max: + partition.end_datetime_range_max = item["end_datetime"] + partition.requires_update = True else: + if not update_enabled: + raise ValueError(f"Partition {partition_name} does not exist.") + # No partition exists yet; create a new one from item partition = Partition( name=partition_name, @@ -544,7 +554,11 @@ def _partition_update(self, item: Dict[str, Any]) -> str: return partition_name - def read_dehydrated(self, file: Union[Path, str] = "stdin") -> Generator: + def read_dehydrated( + self, + file: Union[Path, str] = 
"stdin", + partition_update_enabled: Optional[bool] = True, + ) -> Generator: if file is None: file = "stdin" if isinstance(file, str): @@ -575,15 +589,21 @@ def read_dehydrated(self, file: Union[Path, str] = "stdin") -> Generator: item[field] = content_value else: item[field] = tab_split[i] - item["partition"] = self._partition_update(item) + item["partition"] = self._partition_update( + item, + partition_update_enabled, + ) yield item def read_hydrated( - self, file: Union[Path, str, Iterator[Any]] = "stdin", + self, + file: Union[Path, str, + Iterator[Any]] = "stdin", + partition_update_enabled: Optional[bool] = True, ) -> Generator: for line in read_json(file): item = self.format_item(line) - item["partition"] = self._partition_update(item) + item["partition"] = self._partition_update(item, partition_update_enabled) yield item def load_items( @@ -592,6 +612,7 @@ def load_items( insert_mode: Optional[Methods] = Methods.insert, dehydrated: Optional[bool] = False, chunksize: Optional[int] = 10000, + partition_update_enabled: Optional[bool] = True, ) -> None: """Load items json records.""" self.check_version() @@ -602,15 +623,17 @@ def load_items( self._partition_cache = {} if dehydrated and isinstance(file, str): - items = self.read_dehydrated(file) + items = self.read_dehydrated(file, partition_update_enabled) else: - items = self.read_hydrated(file) + items = self.read_hydrated(file, partition_update_enabled) for chunkin in chunked_iterable(items, chunksize): chunk = list(chunkin) chunk.sort(key=lambda x: x["partition"]) for k, g in itertools.groupby(chunk, lambda x: x["partition"]): - self.load_partition(self._partition_cache[k], g, insert_mode) + self.load_partition( + self._partition_cache[k], g, insert_mode, partition_update_enabled, + ) logger.debug(f"Adding data to database took {time.perf_counter() - t} seconds.") diff --git a/src/pypgstac/python/pypgstac/pypgstac.py b/src/pypgstac/python/pypgstac/pypgstac.py index 1cfbdb3..abab27d 100644 --- 
a/src/pypgstac/python/pypgstac/pypgstac.py +++ b/src/pypgstac/python/pypgstac/pypgstac.py @@ -67,13 +67,16 @@ def load( method: Optional[Methods] = Methods.insert, dehydrated: Optional[bool] = False, chunksize: Optional[int] = 10000, + partition_update_enabled: Optional[bool] = True, ) -> None: """Load collections or items into PGStac.""" loader = Loader(db=self._db) if table == "collections": loader.load_collections(file, method) if table == "items": - loader.load_items(file, method, dehydrated, chunksize) + loader.load_items( + file, method, dehydrated, chunksize, partition_update_enabled, + ) def runqueue(self) -> str: return self._db.run_queued() diff --git a/src/pypgstac/tests/test_load.py b/src/pypgstac/tests/test_load.py index 5500663..47446b6 100644 --- a/src/pypgstac/tests/test_load.py +++ b/src/pypgstac/tests/test_load.py @@ -449,3 +449,20 @@ def test_load_items_nopartitionconstraint_succeeds(loader: Loader) -> None: """, ) assert cdtmin == "2011-07-31 00:00:00+00" + + +def test_load_items_when_partition_creation_disabled(loader: Loader) -> None: + """ + Test pypgstac items loader raises an exception when partition + does not exist and partition creation is disabled. + """ + loader.load_collections( + str(TEST_COLLECTIONS_JSON), + insert_mode=Methods.insert, + ) + with pytest.raises(ValueError): + loader.load_items( + str(TEST_ITEMS), + insert_mode=Methods.insert, + partition_update_enabled=False, + )