diff --git a/lektor_groupby/plugin.py b/lektor_groupby/plugin.py index 9f92f9b..9a7a11d 100644 --- a/lektor_groupby/plugin.py +++ b/lektor_groupby/plugin.py @@ -1,9 +1,10 @@ from lektor.builder import Builder # typing +from lektor.db import Page # typing from lektor.pluginsystem import Plugin # subclass from lektor.sourceobj import SourceObject # typing -from typing import List, Optional, Iterator -from .vobj import GroupBySource, GroupByBuildProgram, VPATH +from typing import List, Optional, Iterator, Any +from .vobj import GroupBySource, GroupByBuildProgram, VPATH, VGroups from .groupby import GroupBy from .pruner import prune from .watcher import GroupByCallbackArgs # typing @@ -13,10 +14,10 @@ class GroupByPlugin(Plugin): name = 'GroupBy Plugin' description = 'Cluster arbitrary records with field attribute keyword.' - def on_setup_env(self, **extra: object) -> None: + def on_setup_env(self, **extra: Any) -> None: self.creator = GroupBy() self.env.add_build_program(GroupBySource, GroupByBuildProgram) - self.env.jinja_env.filters.update(vgroups=GroupBySource.of_record) + self.env.jinja_env.filters.update(vgroups=VGroups.iter) # resolve /tag/rss/ -> /tag/rss/index.html (local server only) @self.env.urlresolver @@ -46,7 +47,7 @@ class GroupByPlugin(Plugin): if isinstance(val, (list, map)): yield from val - def on_before_build_all(self, builder: Builder, **extra: object) -> None: + def on_before_build_all(self, builder: Builder, **extra: Any) -> None: self.creator.clear_previous_results() self._load_quick_config() # let other plugins register their @groupby.watch functions @@ -54,14 +55,15 @@ class GroupByPlugin(Plugin): self.config_dependencies = self.creator.get_dependencies() self.creator.make_cluster(builder) - def on_before_build(self, source: SourceObject, **extra: object) -> None: + def on_before_build(self, source: SourceObject, **extra: Any) -> None: # before-build may be called before before-build-all (issue #1017) # make sure it is evaluated immediatelly - self.creator.queue_now(source) + if isinstance(source, Page): + self.creator.queue_now(source) - def on_after_build_all(self, builder: Builder, **extra: object) -> None: + def on_after_build_all(self, builder: Builder, **extra: Any) -> None: self.creator.build_all(builder) - def on_after_prune(self, builder: Builder, **extra: object) -> None: + def on_after_prune(self, builder: Builder, **extra: Any) -> None: # TODO: find a better way to prune unreferenced elements prune(builder, VPATH) diff --git a/lektor_groupby/vobj.py b/lektor_groupby/vobj.py index 17f097f..02d97ff 100644 --- a/lektor_groupby/vobj.py +++ b/lektor_groupby/vobj.py @@ -51,7 +51,7 @@ class GroupBySource(VirtualSourceObject): if child.pad != record.pad: child = record.pad.get(child.path) self._children[child] = extras - self._reverse_reference_records() + VGroups.of(child).add(self) # extra fields for attr, expr in config.fields.items(): setattr(self, attr, self._eval(expr, field='fields.' + attr)) @@ -138,19 +138,23 @@ class GroupBySource(VirtualSourceObject): return ''.format( self.path, len(self._children)) - # --------------------- - # Reverse Reference - # --------------------- - def _reverse_reference_records(self) -> None: - ''' Attach self to page records. ''' - for child in self._children: - if not hasattr(child, '_vgroups'): - child._vgroups = WeakSet() # type: ignore[attr-defined] - child._vgroups.add(self) # type: ignore[attr-defined] +# ----------------------------------- +# Reverse Reference +# ----------------------------------- + +class VGroups: + @staticmethod + def of(record: Record) -> WeakSet: + try: + wset = record.__vgroups # type: ignore[attr-defined] + except AttributeError: + wset = WeakSet() + record.__vgroups = wset # type: ignore[attr-defined] + return wset # type: ignore[no-any-return] @staticmethod - def of_record( + def iter( record: Record, *keys: str, recursive: bool = False @@ -167,9 +171,8 @@ class GroupBySource(VirtualSourceObject): page = proc_list.pop(0) if recursive and hasattr(page, 'children'): proc_list.extend(page.children) # type: ignore[attr-defined] - if not hasattr(page, '_vgroups'): - continue - for vobj in page._vgroups: # type: ignore[attr-defined] + for vobj in VGroups.of(page): + vobj.config.dependencies if not keys or vobj.config.key in keys: yield vobj