diff --git a/lektor_groupby/backref.py b/lektor_groupby/backref.py index 6128782..a0a17ca 100644 --- a/lektor_groupby/backref.py +++ b/lektor_groupby/backref.py @@ -1,5 +1,5 @@ from lektor.context import get_ctx -from typing import TYPE_CHECKING, Union, Iterable, Iterator +from typing import TYPE_CHECKING, Union, Iterable, Iterator, Optional import weakref if TYPE_CHECKING: from lektor.builder import Builder @@ -47,7 +47,8 @@ class VGroups: *, fields: Union[str, Iterable[str], None] = None, flows: Union[str, Iterable[str], None] = None, - recursive: bool = False + recursive: bool = False, + order_by: Optional[str] = None, ) -> Iterator['GroupBySource']: ''' Extract all referencing groupby virtual objects from a page. ''' ctx = get_ctx() @@ -69,6 +70,7 @@ class VGroups: flows = [flows] # find groups proc_list = [record] + done_list = set() while proc_list: page = proc_list.pop(0) if recursive and hasattr(page, 'children'): @@ -80,4 +82,10 @@ class VGroups: continue if keys and vobj().config.key not in keys: continue - yield vobj() + done_list.add(vobj()) + + if order_by: + order = order_by.split(',') + yield from sorted(done_list, key=lambda x: x.get_sort_key(order)) + else: + yield from done_list diff --git a/lektor_groupby/vobj.py b/lektor_groupby/vobj.py index 5d07c00..5144d33 100644 --- a/lektor_groupby/vobj.py +++ b/lektor_groupby/vobj.py @@ -1,9 +1,10 @@ from lektor.build_programs import BuildProgram # subclass from lektor.context import get_ctx +from lektor.db import _CmpHelper from lektor.environment import Expression from lektor.sourceobj import VirtualSourceObject # subclass from lektor.utils import build_url -from typing import TYPE_CHECKING, Dict, List, Any, Optional, Iterator +from typing import TYPE_CHECKING, Dict, List, Any, Optional, Iterator, Iterable from .util import report_config_error, most_used_key if TYPE_CHECKING: from lektor.builder import Artifact @@ -93,6 +94,15 @@ class GroupBySource(VirtualSourceObject): for record in self._children: yield from record.iter_source_filenames() + def get_sort_key(self, fields: Iterable[str]) -> List: + def cmp_val(field: str) -> Any: + reverse = field.startswith('-') + if reverse or field.startswith('+'): + field = field[1:] + return _CmpHelper(getattr(self, field, None), reverse) + + return [cmp_val(field) for field in fields] + # ----------------------- # Properties & Helper # -----------------------