From 7ca5b00091b0f1d7de28d7496344e1449fcfa3f9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi?= <remi.cresson@inrae.fr>
Date: Wed, 2 Oct 2024 11:13:06 +0200
Subject: [PATCH] enh: support collections

---
 stac_extension_genmeta/__init__.py |  2 +-
 stac_extension_genmeta/core.py     | 17 ++++++++++++-----
 stac_extension_genmeta/testing.py  | 23 ++++++++++++++++++++---
 3 files changed, 33 insertions(+), 9 deletions(-)

diff --git a/stac_extension_genmeta/__init__.py b/stac_extension_genmeta/__init__.py
index ea36f6d..b50774d 100644
--- a/stac_extension_genmeta/__init__.py
+++ b/stac_extension_genmeta/__init__.py
@@ -1,2 +1,2 @@
 from .core import create_extension_cls
-__version__ = "0.0.21"
\ No newline at end of file
+__version__ = "0.0.22"
\ No newline at end of file
diff --git a/stac_extension_genmeta/core.py b/stac_extension_genmeta/core.py
index ae7a88a..e4aba2b 100644
--- a/stac_extension_genmeta/core.py
+++ b/stac_extension_genmeta/core.py
@@ -44,7 +44,7 @@ def create_extension_cls(
         def __init__(self, obj: T):
             if isinstance(obj, pystac.Item):
                 self.properties = obj.properties
-            elif isinstance(obj, pystac.Asset):
+            elif isinstance(obj, (pystac.Asset, pystac.Collection)):
                 self.properties = obj.extra_fields
             else:
                 raise pystac.ExtensionTypeError(
@@ -109,12 +109,12 @@ def create_extension_cls(
         ) -> model_cls.__name__:
             if isinstance(obj, pystac.Item):
                 cls.ensure_has_extension(obj, add_if_missing)
-                return cast(CustomExtension[T],
-                            ItemCustomExtension(obj))
+                return cast(CustomExtension[T], ItemCustomExtension(obj))
             elif isinstance(obj, pystac.Asset):
                 cls.ensure_owner_has_extension(obj, add_if_missing)
-                return cast(CustomExtension[T],
-                            AssetCustomExtension(obj))
+                return cast(CustomExtension[T], AssetCustomExtension(obj))
+            elif isinstance(obj, pystac.Collection):
+                return cast(CustomExtension[T], CollectionCustomExtension(obj))
             raise pystac.ExtensionTypeError(
                 f"{model_cls.__name__} does not apply to type "
                 f"{type(obj).__name__}"
@@ -134,5 +134,12 @@ def create_extension_cls(
             if asset.owner and isinstance(asset.owner, pystac.Item):
                 self.additional_read_properties = [asset.owner.properties]
 
+    class CollectionCustomExtension(CustomExtension[pystac.Collection]):
+        properties: dict[str, Any]
+        additional_read_properties: Iterable[dict[str, Any]] | None = None
+
+        def __init__(self, collection: pystac.Collection):
+            self.properties = collection.extra_fields
+
     CustomExtension.__name__ = f"CustomExtensionFrom{model_cls.__name__}"
     return CustomExtension
diff --git a/stac_extension_genmeta/testing.py b/stac_extension_genmeta/testing.py
index 518e063..24ecd44 100644
--- a/stac_extension_genmeta/testing.py
+++ b/stac_extension_genmeta/testing.py
@@ -50,7 +50,7 @@ def create_dummy_item(date=None):
     )
     col.add_item(item)
 
-    return item
+    return item, col
 
 
 def basic_test(
@@ -58,6 +58,7 @@ def basic_test(
         ext_cls,
         item_test: bool = True,
         asset_test: bool = True,
+        collection_test: bool = True,
         validate: bool = True
 ):
     print(
@@ -92,7 +93,7 @@ def basic_test(
         """
         Test extension against item
         """
-        item = create_dummy_item()
+        item, _ = create_dummy_item()
         apply(item)
         print_item(item)
         if validate:
@@ -104,7 +105,7 @@ def basic_test(
         """
         Test extension against asset
         """
-        item = create_dummy_item()
+        item, _ = create_dummy_item()
         apply(item.assets["ndvi"])
         print_item(item)
         if validate:
@@ -112,12 +113,28 @@ def basic_test(
         # Check that we can retrieve the extension metadata from the asset
         comp(item.assets["ndvi"])
 
+    def test_collection():
+        """
+        Test extension against collection
+        """
+        item, col = create_dummy_item()
+        print_item(col)
+        apply(col)
+        print_item(col)
+        if validate:
+            col.validate()  # <--- This will try to read the actual schema URI
+        # Check that we can retrieve the extension metadata from the asset
+        comp(col)
+
     if item_test:
         print("Test item")
         test_item()
     if asset_test:
         print("Test asset")
         test_asset()
+    if collection_test:
+        print("Test collection")
+        test_collection()
 
 
 def is_schema_url_synced(cls):
-- 
GitLab