diff --git a/lektor_groupby/backref.py b/lektor_groupby/backref.py index 892f95f..6128782 100644 --- a/lektor_groupby/backref.py +++ b/lektor_groupby/backref.py @@ -1,13 +1,19 @@ from lektor.context import get_ctx -from typing import TYPE_CHECKING, Iterator -from weakref import WeakSet +from typing import TYPE_CHECKING, Union, Iterable, Iterator +import weakref if TYPE_CHECKING: from lektor.builder import Builder from lektor.db import Record from .groupby import GroupBy + from .model import FieldKeyPath from .vobj import GroupBySource +class WeakVGroupsList(list): + def add(self, strong: 'FieldKeyPath', weak: 'GroupBySource') -> None: + super().append((strong, weakref.ref(weak))) + + class GroupByRef: @staticmethod def of(builder: 'Builder') -> 'GroupBy': @@ -22,7 +28,7 @@ class GroupByRef: class VGroups: @staticmethod - def of(record: 'Record') -> WeakSet: + def of(record: 'Record') -> WeakVGroupsList: ''' Return the (weak) set of virtual objects of a page. Creates a new set if it does not exist yet. @@ -30,13 +36,19 @@ class VGroups: try: wset = record.__vgroups # type: ignore[attr-defined] except AttributeError: - wset = WeakSet() + wset = WeakVGroupsList() record.__vgroups = wset # type: ignore[attr-defined] return wset # type: ignore[no-any-return] @staticmethod - def iter(record: 'Record', *keys: str, recursive: bool = False) \ - -> Iterator['GroupBySource']: + def iter( + record: 'Record', + keys: Union[str, Iterable[str], None] = None, + *, + fields: Union[str, Iterable[str], None] = None, + flows: Union[str, Iterable[str], None] = None, + recursive: bool = False + ) -> Iterator['GroupBySource']: ''' Extract all referencing groupby virtual objects from a page. ''' ctx = get_ctx() if not ctx: @@ -48,12 +60,24 @@ class VGroups: # manage config dependencies for dep in groupby.dependencies: ctx.record_dependency(dep) + # prepare filter + if isinstance(keys, str): + keys = [keys] + if isinstance(fields, str): + fields = [fields] + if isinstance(flows, str): + flows = [flows] # find groups proc_list = [record] while proc_list: page = proc_list.pop(0) if recursive and hasattr(page, 'children'): proc_list.extend(page.children) # type: ignore[attr-defined] - for vobj in VGroups.of(page): - if not keys or vobj.config.key in keys: - yield vobj + for key, vobj in VGroups.of(page): + if fields and key.fieldKey not in fields: + continue + if flows and key.flowKey not in flows: + continue + if keys and vobj().config.key not in keys: + continue + yield vobj() diff --git a/lektor_groupby/groupby.py b/lektor_groupby/groupby.py index 8d01bb7..372ccd8 100644 --- a/lektor_groupby/groupby.py +++ b/lektor_groupby/groupby.py @@ -44,7 +44,7 @@ class GroupBy: return # initialize remaining (enabled) watchers for w in self._watcher: - w.initialize(builder.pad.db) + w.initialize(builder.pad) # iterate over whole build tree queue = builder.pad.get_all_roots() # type: List[SourceObject] while queue: diff --git a/lektor_groupby/resolver.py b/lektor_groupby/resolver.py index 98b7162..719f8ed 100644 --- a/lektor_groupby/resolver.py +++ b/lektor_groupby/resolver.py @@ -15,7 +15,7 @@ class Resolver: ''' def __init__(self, env: 'Environment') -> None: - self._data = {} # type: Dict[str, Tuple[str, Config]] + self._data = {} # type: Dict[str, Tuple[str, str, Config]] env.urlresolver(self.resolve_server_path) env.virtualpathresolver(VPATH.lstrip('@'))(self.resolve_virtual_path) @@ -34,7 +34,7 @@ class Resolver: def add(self, vobj: GroupBySource) -> None: ''' Track new virtual object (only if slug is set). ''' if vobj.slug: - self._data[vobj.url_path] = (vobj.group, vobj.config) + self._data[vobj.url_path] = (vobj.key, vobj.group, vobj.config) # ------------ # Resolver @@ -46,7 +46,7 @@ class Resolver: if isinstance(node, Record): rv = self._data.get(build_url([node.url_path] + pieces)) if rv: - return GroupBySource(node, group=rv[0], config=rv[1]) + return GroupBySource(node, rv[0]).finalize(rv[2], rv[1]) return None def resolve_virtual_path(self, node: 'SourceObject', pieces: List[str]) \ @@ -54,9 +54,8 @@ class Resolver: ''' Admin UI only: Prevent server error and null-redirect. ''' if isinstance(node, Record) and len(pieces) >= 2: path = node['_path'] # type: str - key, grp, *_ = pieces - for group, conf in self._data.values(): - if key == conf.key and path == conf.root: - if conf.slugify(group) == grp: - return GroupBySource(node, group, conf) + attr, grp, *_ = pieces + for slug, group, conf in self._data.values(): + if attr == conf.key and slug == grp and path == conf.root: + return GroupBySource(node, slug).finalize(conf, group) return None diff --git a/lektor_groupby/vobj.py b/lektor_groupby/vobj.py index d9cb38a..05597d0 100644 --- a/lektor_groupby/vobj.py +++ b/lektor_groupby/vobj.py @@ -4,8 +4,7 @@ from lektor.environment import Expression from lektor.sourceobj import VirtualSourceObject # subclass from lektor.utils import build_url from typing import TYPE_CHECKING, Dict, List, Any, Optional, Iterator -from .backref import VGroups -from .util import report_config_error +from .util import report_config_error, most_used_key if TYPE_CHECKING: from lektor.builder import Artifact from lektor.db import Record @@ -25,18 +24,29 @@ class GroupBySource(VirtualSourceObject): Attributes: record, key, group, slug, children, config ''' - def __init__( - self, - record: 'Record', - group: str, - config: 'Config', - children: Optional[Dict['Record', List[Any]]] = None, - ) -> None: + def __init__(self, record: 'Record', slug: str) -> None: super().__init__(record) - self.key = config.slugify(group) - self.group = group + self.key = slug + self._group_map = [] # type: List[str] + self._children = {} # type: Dict[Record, List[Any]] + + def append_child(self, child: 'Record', extra: Any, group: str) -> None: + if child not in self._children: + self._children[child] = [extra] + else: + self._children[child].append(extra) + # _group_map is later used to find most used group + self._group_map.append(group) + + # ------------------------- + # Evaluate Extra Fields + # ------------------------- + + def finalize(self, config: 'Config', group: Optional[str] = None) \ + -> 'GroupBySource': self.config = config - self._children = children or {} # type: Dict[Record, List[Any]] + self.group = group or most_used_key(self._group_map) + del self._group_map # evaluate slug Expression if config.slug and '{key}' in config.slug: self.slug = config.slug.replace('{key}', self.key) @@ -48,9 +58,7 @@ class GroupBySource(VirtualSourceObject): # extra fields for attr, expr in config.fields.items(): setattr(self, attr, self._eval(expr, field='fields.' + attr)) - # back-ref - for child in self._children: - VGroups.of(child).add(self) + return self def _eval(self, value: Any, *, field: str) -> Any: ''' Internal only: evaluates Lektor config file field expression. ''' diff --git a/lektor_groupby/watcher.py b/lektor_groupby/watcher.py index dc01d60..e3c3fd7 100644 --- a/lektor_groupby/watcher.py +++ b/lektor_groupby/watcher.py @@ -1,10 +1,10 @@ from typing import TYPE_CHECKING, Dict, List, Tuple, Any, Union, NamedTuple from typing import Optional, Callable, Iterator, Generator +from .backref import VGroups from .model import ModelReader -from .util import most_used_key from .vobj import GroupBySource if TYPE_CHECKING: - from lektor.db import Database, Record + from lektor.db import Pad, Record from .config import Config from .model import FieldKeyPath @@ -44,12 +44,12 @@ class Watcher: self.callback = fn return _decorator - def initialize(self, db: 'Database') -> None: + def initialize(self, pad: 'Pad') -> None: ''' Reset internal state. You must initialize before each build! ''' assert callable(self.callback), 'No grouping callback provided.' - self._model_reader = ModelReader(db, self.config.key, self.flatten) - self._state = {} # type: Dict[str, Dict[Record, List[Any]]] - self._group_map = {} # type: Dict[str, List[str]] + self._model_reader = ModelReader(pad.db, self.config.key, self.flatten) + self._root_record = pad.get(self._root) # type: Record + self._state = {} # type: Dict[str, GroupBySource] def should_process(self, node: 'Record') -> bool: ''' Check if record path is being watched. ''' @@ -77,39 +77,34 @@ class Watcher: del _gen def _persist( - self, - record: 'Record', - key: 'FieldKeyPath', - obj: Union[str, tuple] + self, record: 'Record', key: 'FieldKeyPath', obj: Union[str, tuple] ) -> str: ''' Update internal state. Return slugified string. ''' - group = obj if isinstance(obj, str) else obj[0] - slug = self.config.slugify(group) - # init group-key - if slug not in self._state: - self._state[slug] = {} - self._group_map[slug] = [] - # _group_map is later used to find most used group - self._group_map[slug].append(group) - # init group extras - if record not in self._state[slug]: - self._state[slug][record] = [] - # append extras (or default value) - if isinstance(obj, tuple): - self._state[slug][record].append(obj[1]) + if isinstance(obj, str): + group, extra = obj, key.fieldKey else: - self._state[slug][record].append(key.fieldKey) + group, extra = obj + + slug = self.config.slugify(group) + if slug not in self._state: + src = GroupBySource(self._root_record, slug) + self._state[slug] = src + else: + src = self._state[slug] + + src.append_child(record, extra, group) + # reverse reference + VGroups.of(record).add(key, src) return slug def iter_sources(self, root: 'Record') -> Iterator[GroupBySource]: ''' Prepare and yield GroupBySource elements. ''' - for key, children in self._state.items(): - group = most_used_key(self._group_map[key]) - yield GroupBySource(root, group, self.config, children=children) + for vobj in self._state.values(): + yield vobj.finalize(self.config) # cleanup. remove this code if you'd like to iter twice del self._model_reader + del self._root_record del self._state - del self._group_map def __repr__(self) -> str: return ''.format(