fix: build queue and dependencies + add key_map_fn

This commit is contained in:
relikd
2022-11-22 10:58:14 +01:00
parent e7ae59fadf
commit 0891be06e2
9 changed files with 300 additions and 134 deletions

View File

@@ -53,16 +53,6 @@ class VGroups:
order_by: Union[str, Iterable[str], None] = None, order_by: Union[str, Iterable[str], None] = None,
) -> Iterator['GroupBySource']: ) -> Iterator['GroupBySource']:
''' Extract all referencing groupby virtual objects from a page. ''' ''' Extract all referencing groupby virtual objects from a page. '''
ctx = get_ctx()
if not ctx:
raise NotImplementedError("Shouldn't happen, where is my context?")
# get GroupBy object
builder = ctx.build_state.builder
groupby = GroupByRef.of(builder)
groupby.make_once(builder) # ensure did cluster before
# manage config dependencies
for dep in groupby.dependencies:
ctx.record_dependency(dep)
# prepare filter # prepare filter
if isinstance(keys, str): if isinstance(keys, str):
keys = [keys] keys = [keys]
@@ -70,6 +60,13 @@ class VGroups:
fields = [fields] fields = [fields]
if isinstance(flows, str): if isinstance(flows, str):
flows = [flows] flows = [flows]
# get GroupBy object
ctx = get_ctx()
if not ctx:
raise NotImplementedError("Shouldn't happen, where is my context?")
builder = ctx.build_state.builder
# TODO: fix record_dependency -> process in non-capturing context
GroupByRef.of(builder).make_once(keys) # ensure did cluster before use
# find groups # find groups
proc_list = [record] proc_list = [record]
done_list = set() # type: Set[GroupBySource] done_list = set() # type: Set[GroupBySource]
@@ -86,6 +83,13 @@ class VGroups:
continue continue
done_list.add(vobj()) done_list.add(vobj())
# manage config dependencies
deps = set() # type: Set[str]
for vobj in done_list:
deps.update(vobj.config.dependencies)
for dep in deps:
ctx.record_dependency(dep)
if order_by: if order_by:
if isinstance(order_by, str): if isinstance(order_by, str):
order = split_strip(order_by, ',') # type: Iterable[str] order = split_strip(order_by, ',') # type: Iterable[str]

View File

@@ -1,12 +1,33 @@
from inifile import IniFile from inifile import IniFile
from lektor.utils import slugify from lektor.environment import Expression
from lektor.context import Context
from lektor.utils import slugify as _slugify
from typing import TYPE_CHECKING
from typing import Set, Dict, Optional, Union, Any, List, Generator
from .util import split_strip from .util import split_strip
if TYPE_CHECKING:
from lektor.sourceobj import SourceObject
from typing import Set, Dict, Optional, Union, Any, List
AnyConfig = Union['Config', IniFile, Dict] AnyConfig = Union['Config', IniFile, Dict]
class ConfigError(Exception):
    ''' Describes a bad config value. Used to print a Lektor console error. '''

    def __init__(
        self, key: str, field: str, expr: str, error: Union[Exception, str]
    ):
        # `key` + `field` locate the config entry, `expr` is its raw value
        # and `error` is the reason it was rejected.
        self.key, self.field = key, field
        self.expr, self.error = expr, error

    def __str__(self) -> str:
        # `error` is shown with repr() to keep the exception type visible
        details = (self.key, self.field, self.expr, repr(self.error))
        template = 'Invalid config for [{}.{}] = "{}" Error: {}'
        return template.format(*details)
class Config: class Config:
''' '''
Holds information for GroupByWatcher and GroupBySource. Holds information for GroupByWatcher and GroupBySource.
@@ -22,11 +43,15 @@ class Config:
root: Optional[str] = None, # default: "/" root: Optional[str] = None, # default: "/"
slug: Optional[str] = None, # default: "{attr}/{group}/index.html" slug: Optional[str] = None, # default: "{attr}/{group}/index.html"
template: Optional[str] = None, # default: "groupby-{attr}.html" template: Optional[str] = None, # default: "groupby-{attr}.html"
replace_none_key: Optional[str] = None, # default: None
key_map_fn: Optional[str] = None, # default: None
) -> None: ) -> None:
self.key = key self.key = key
self.root = (root or '/').rstrip('/') or '/' self.root = (root or '/').rstrip('/') or '/'
self.slug = slug or (key + '/{key}/') # key = GroupBySource.key self.slug = slug or (key + '/{key}/') # key = GroupBySource.key
self.template = template or f'groupby-{self.key}.html' self.template = template or f'groupby-{self.key}.html'
self.replace_none_key = replace_none_key
self.key_map_fn = key_map_fn
# editable after init # editable after init
self.enabled = True self.enabled = True
self.dependencies = set() # type: Set[str] self.dependencies = set() # type: Set[str]
@@ -37,7 +62,8 @@ class Config:
def slugify(self, k: str) -> str: def slugify(self, k: str) -> str:
''' key_map replace and slugify. ''' ''' key_map replace and slugify. '''
return slugify(self.key_map.get(k, k)) # type: ignore[no-any-return] rv = self.key_map.get(k, k)
return _slugify(rv) or rv # the `or` allows for example "_"
def set_fields(self, fields: Optional[Dict[str, Any]]) -> None: def set_fields(self, fields: Optional[Dict[str, Any]]) -> None:
''' '''
@@ -72,7 +98,7 @@ class Config:
def __repr__(self) -> str: def __repr__(self) -> str:
txt = '<GroupByConfig' txt = '<GroupByConfig'
for x in ['key', 'root', 'slug', 'template', 'enabled']: for x in ['enabled', 'key', 'root', 'slug', 'template', 'key_map_fn']:
txt += ' {}="{}"'.format(x, getattr(self, x)) txt += ' {}="{}"'.format(x, getattr(self, x))
txt += f' fields="{", ".join(self.fields)}"' txt += f' fields="{", ".join(self.fields)}"'
if self.order_by: if self.order_by:
@@ -87,6 +113,8 @@ class Config:
root=cfg.get('root'), root=cfg.get('root'),
slug=cfg.get('slug'), slug=cfg.get('slug'),
template=cfg.get('template'), template=cfg.get('template'),
replace_none_key=cfg.get('replace_none_key'),
key_map_fn=cfg.get('key_map_fn'),
) )
@staticmethod @staticmethod
@@ -116,3 +144,56 @@ class Config:
return Config.from_ini(key, config) return Config.from_ini(key, config)
elif isinstance(config, Dict): elif isinstance(config, Dict):
return Config.from_dict(key, config) return Config.from_dict(key, config)
# -----------------------------------
# Field Expressions
# -----------------------------------
def _make_expression(self, expr: Any, *, on: 'SourceObject', field: str) \
-> Union[Expression, Any]:
''' Create Expression and report any config error. '''
if not isinstance(expr, str):
return expr
try:
return Expression(on.pad.env, expr)
except Exception as e:
raise ConfigError(self.key, field, expr, e)
def eval_field(self, attr: str, *, on: 'SourceObject') \
        -> Union[Expression, Any]:
    ''' Build the expression for the user defined field `attr`. '''
    # No `gather_dependencies` here: fields are evaluated on the fly and
    # dependency tracking happens whenever a field is accessed.
    raw_value = self.fields[attr]
    field_label = 'fields.' + attr
    return self._make_expression(raw_value, on=on, field=field_label)
def eval_slug(self, key: str, *, on: 'SourceObject') -> Optional[str]:
    '''
    Either perform a "{key}" substitution or evaluate expression.

    key: value inserted for the "{key}" placeholder of the config slug.
    on: source object the slug expression is evaluated against.
    Returns: the resulting slug, or None if no slug is configured or the
             expression evaluates to something falsy.
    Raises: ConfigError if the slug contains "{key}" but `key` is empty.
    '''
    cfg_slug = self.slug
    if not cfg_slug:
        return None
    if '{key}' in cfg_slug:
        if not key:
            raise ConfigError(self.key, 'slug', cfg_slug,
                              'Cannot replace {key} with None')
        return cfg_slug.replace('{key}', key)
    # no placeholder -> treat the slug as an expression
    # TODO: do we need `gather_dependencies` here too?
    expr = self._make_expression(cfg_slug, on=on, field='slug')
    return expr.evaluate(on.pad, this=on, alt=on.alt) or None
def eval_key_map_fn(self, *, on: 'SourceObject', context: Dict) -> Any:
    '''
    If `key_map_fn` is set, evaluate field expression.
    Note: The function does not check whether `key_map_fn` is set.
    Return: A Generator result is automatically unpacked into a list.

    on: source object the expression is evaluated against.
    context: extra values made available to the expression.
    '''
    exp = self._make_expression(self.key_map_fn, on=on, field='key_map_fn')
    with Context(pad=on.pad) as ctx:
        # dependencies gathered during evaluation go to `self.dependencies`
        with ctx.gather_dependencies(self.dependencies.add):
            res = exp.evaluate(on.pad, this=on, alt=on.alt, values=context)
            if isinstance(res, Generator):
                # A generator only runs when iterated, so consume it while
                # the dependency gathering above is still active.
                res = list(res)  # unpack for 1-to-n replacement
    return res

View File

@@ -1,6 +1,7 @@
from lektor.builder import PathCache from lektor.builder import PathCache
from lektor.db import Record # isinstance from lektor.db import Record # isinstance
from typing import TYPE_CHECKING, Set, List from lektor.reporter import reporter # build
from typing import TYPE_CHECKING, List, Optional, Iterable
from .config import Config from .config import Config
from .watcher import Watcher from .watcher import Watcher
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -19,14 +20,14 @@ class GroupBy:
''' '''
def __init__(self, resolver: 'Resolver') -> None: def __init__(self, resolver: 'Resolver') -> None:
self._building = False
self._watcher = [] # type: List[Watcher] self._watcher = [] # type: List[Watcher]
self._results = [] # type: List[GroupBySource] self._results = [] # type: List[GroupBySource]
self.resolver = resolver self.resolver = resolver
self.didBuild = False
@property @property
def isNew(self) -> bool: def isBuilding(self) -> bool:
return not self.didBuild return self._building
def add_watcher(self, key: str, config: 'AnyConfig') -> Watcher: def add_watcher(self, key: str, config: 'AnyConfig') -> Watcher:
''' Init Config and add to watch list. ''' ''' Init Config and add to watch list. '''
@@ -34,15 +35,8 @@ class GroupBy:
self._watcher.append(w) self._watcher.append(w)
return w return w
def get_dependencies(self) -> Set[str]:
deps = set() # type: Set[str]
for w in self._watcher:
deps.update(w.config.dependencies)
return deps
def queue_all(self, builder: 'Builder') -> None: def queue_all(self, builder: 'Builder') -> None:
''' Iterate full site-tree and queue all children. ''' ''' Iterate full site-tree and queue all children. '''
self.dependencies = self.get_dependencies()
# remove disabled watchers # remove disabled watchers
self._watcher = [w for w in self._watcher if w.config.enabled] self._watcher = [w for w in self._watcher if w.config.enabled]
if not self._watcher: if not self._watcher:
@@ -61,16 +55,27 @@ class GroupBy:
if isinstance(record, Record): if isinstance(record, Record):
for w in self._watcher: for w in self._watcher:
if w.should_process(record): if w.should_process(record):
w.process(record) w.remember(record)
def make_once(self, builder: 'Builder') -> None: def make_once(self, filter_keys: Optional[Iterable[str]] = None) -> None:
''' Perform groupby, iter over sources with watcher callback. ''' '''
self.didBuild = True Perform groupby, iter over sources with watcher callback.
if self._watcher: If `filter_keys` is set, ignore all other watchers.
'''
if not self._watcher:
return
# not really necessary but should improve performance of later reset()
if not filter_keys:
self.resolver.reset() self.resolver.reset()
remaining = []
for w in self._watcher: for w in self._watcher:
root = builder.pad.get(w.config.root) # only process vobjs that are used somewhere
for vobj in w.iter_sources(root): if filter_keys and w.config.key not in filter_keys:
remaining.append(w)
continue
self.resolver.reset(w.config.key)
# these are used in the current context (or on `build_all`)
for vobj in w.iter_sources():
# add original source # add original source
self._results.append(vobj) self._results.append(vobj)
self.resolver.add(vobj) self.resolver.add(vobj)
@@ -78,14 +83,30 @@ class GroupBy:
for sub_vobj in vobj.__iter_pagination_sources__(): for sub_vobj in vobj.__iter_pagination_sources__():
self._results.append(sub_vobj) self._results.append(sub_vobj)
self.resolver.add(sub_vobj) self.resolver.add(sub_vobj)
self._watcher.clear() # TODO: if this should ever run concurrently, pop() from watchers
self._watcher = remaining
def build_all(self, builder: 'Builder') -> None: def build_all(
''' Create virtual objects and build sources. ''' self,
self.make_once(builder) # in case no page used the |vgroups filter builder: 'Builder',
specific: Optional['GroupBySource'] = None
) -> None:
'''
Build actual artifacts (if needed).
If `specific` is set, only build the artifacts for that single vobj
'''
if not self._watcher and not self._results:
return
with reporter.build('groupby', builder): # type:ignore
# in case no page used the |vgroups filter
self.make_once([specific.config.key] if specific else None)
self._building = True
path_cache = PathCache(builder.env) path_cache = PathCache(builder.env)
for vobj in self._results: for vobj in self._results:
if specific and vobj.path != specific.path:
continue
if vobj.slug: if vobj.slug:
builder.build(vobj, path_cache) builder.build(vobj, path_cache)
del path_cache del path_cache
self._building = False
self._results.clear() # garbage collect weak refs self._results.clear() # garbage collect weak refs

View File

@@ -48,6 +48,7 @@ class ModelReader:
for r_key, subs in self._models.get(record.datamodel.id, {}).items(): for r_key, subs in self._models.get(record.datamodel.id, {}).items():
field = record[r_key] field = record[r_key]
if not field: if not field:
yield FieldKeyPath(r_key), field
continue continue
if subs == '*': # either normal field or flow type (all blocks) if subs == '*': # either normal field or flow type (all blocks)
if self.flatten and isinstance(field, Flow): if self.flatten and isinstance(field, Flow):

View File

@@ -7,7 +7,7 @@ from .pruner import prune
from .resolver import Resolver from .resolver import Resolver
from .vobj import VPATH, GroupBySource, GroupByBuildProgram from .vobj import VPATH, GroupBySource, GroupByBuildProgram
if TYPE_CHECKING: if TYPE_CHECKING:
from lektor.builder import Builder, BuildState from lektor.builder import Builder
from lektor.sourceobj import SourceObject from lektor.sourceobj import SourceObject
from .watcher import GroupByCallbackArgs from .watcher import GroupByCallbackArgs
@@ -17,7 +17,6 @@ class GroupByPlugin(Plugin):
description = 'Cluster arbitrary records with field attribute keyword.' description = 'Cluster arbitrary records with field attribute keyword.'
def on_setup_env(self, **extra: Any) -> None: def on_setup_env(self, **extra: Any) -> None:
self.has_changes = False
self.resolver = Resolver(self.env) self.resolver = Resolver(self.env)
self.env.add_build_program(GroupBySource, GroupByBuildProgram) self.env.add_build_program(GroupBySource, GroupByBuildProgram)
self.env.jinja_env.filters.update(vgroups=VGroups.iter) self.env.jinja_env.filters.update(vgroups=VGroups.iter)
@@ -30,19 +29,13 @@ class GroupByPlugin(Plugin):
if isinstance(source, Asset): if isinstance(source, Asset):
return return
groupby = self._init_once(builder) groupby = self._init_once(builder)
if groupby.isNew and isinstance(source, GroupBySource): if not groupby.isBuilding and isinstance(source, GroupBySource):
self.has_changes = True # TODO: differentiate between actual build and browser preview
groupby.build_all(builder, source)
def on_after_build(self, build_state: 'BuildState', **extra: Any) -> None:
if build_state.updated_artifacts:
self.has_changes = True
def on_after_build_all(self, builder: 'Builder', **extra: Any) -> None: def on_after_build_all(self, builder: 'Builder', **extra: Any) -> None:
# only rebuild if has changes (bypass idle builds) # by now, most likely already built. So, build_all() is a no-op
# or the very first time after startup (url resolver & pruning) self._init_once(builder).build_all(builder)
if self.has_changes or not self.resolver.has_any:
self._init_once(builder).build_all(builder) # updates resolver
self.has_changes = False
def on_after_prune(self, builder: 'Builder', **extra: Any) -> None: def on_after_prune(self, builder: 'Builder', **extra: Any) -> None:
# TODO: find a better way to prune unreferenced elements # TODO: find a better way to prune unreferenced elements
@@ -78,7 +71,11 @@ class GroupByPlugin(Plugin):
@watcher.grouping() @watcher.grouping()
def _fn(args: 'GroupByCallbackArgs') -> Iterator[str]: def _fn(args: 'GroupByCallbackArgs') -> Iterator[str]:
val = args.field val = args.field
if isinstance(val, str): if isinstance(val, str) and val != '':
val = map(str.strip, val.split(split)) if split else [val] val = map(str.strip, val.split(split)) if split else [val]
elif isinstance(val, (bool, int, float)):
val = [val]
elif not val: # after checking for '', False, 0, and 0.0
val = [None]
if isinstance(val, (list, map)): if isinstance(val, (list, map)):
yield from val yield from val

View File

@@ -42,8 +42,12 @@ class Resolver:
def files(self) -> Iterable[str]: def files(self) -> Iterable[str]:
return self._data return self._data
def reset(self) -> None: def reset(self, optional_key: Optional[str] = None) -> None:
''' Clear previously recorded virtual objects. ''' ''' Clear previously recorded virtual objects. '''
if optional_key:
self._data = {k: v for k, v in self._data.items()
if v.config.key != optional_key}
else:
self._data.clear() self._data.clear()
def add(self, vobj: GroupBySource) -> None: def add(self, vobj: GroupBySource) -> None:

View File

@@ -1,20 +1,9 @@
from lektor.reporter import reporter, style
from typing import List, Dict, Optional, TypeVar from typing import List, Dict, Optional, TypeVar
from typing import Callable, Any, Union, Generic from typing import Callable, Any, Union, Generic
T = TypeVar('T') T = TypeVar('T')
def report_config_error(key: str, field: str, val: str, e: Exception) -> None:
''' Send error message to Lektor reporter. Indicate which field is bad. '''
msg = '[ERROR] invalid config for [{}.{}] = "{}", Error: {}'.format(
key, field, val, repr(e))
try:
reporter._write_line(style(msg, fg='red'))
except Exception:
print(msg) # fallback in case Lektor API changes
def most_used_key(keys: List[T]) -> Optional[T]: def most_used_key(keys: List[T]) -> Optional[T]:
''' Find string with most occurrences. ''' ''' Find string with most occurrences. '''
if len(keys) < 3: if len(keys) < 3:
@@ -58,3 +47,16 @@ def build_url(parts: List[str]) -> str:
if '.' not in url.rsplit('/', 1)[-1]: if '.' not in url.rsplit('/', 1)[-1]:
url += '/' url += '/'
return url or '/' return url or '/'
class cached_property(Generic[T]):
    '''
    Calculate complex property only once.

    The first access calls the wrapped function and stores the result in
    the instance `__dict__` under the function name. Accessing the
    attribute on the class returns the descriptor itself.
    '''

    def __init__(self, fn: Callable[[Any], T]) -> None:
        self.fn = fn
        # mirror the wrapped function so help() and introspection still work
        self.__name__ = fn.__name__
        self.__doc__ = fn.__doc__

    def __get__(self, obj: object, typ: Union[type, None] = None) -> T:
        if obj is None:
            return self  # type: ignore
        # store the result on the instance under the function name
        ret = obj.__dict__[self.fn.__name__] = self.fn(obj)
        return ret

View File

@@ -3,14 +3,11 @@ from lektor.context import get_ctx
from lektor.db import _CmpHelper from lektor.db import _CmpHelper
from lektor.environment import Expression from lektor.environment import Expression
from lektor.sourceobj import VirtualSourceObject # subclass from lektor.sourceobj import VirtualSourceObject # subclass
from werkzeug.utils import cached_property from typing import TYPE_CHECKING
from typing import List, Any, Dict, Optional, Generator, Iterator, Iterable
from typing import TYPE_CHECKING, List, Any, Optional, Iterator, Iterable
from .pagination import PaginationConfig from .pagination import PaginationConfig
from .query import FixedRecordsQuery from .query import FixedRecordsQuery
from .util import ( from .util import most_used_key, insert_before_ext, build_url, cached_property
report_config_error, most_used_key, insert_before_ext, build_url
)
if TYPE_CHECKING: if TYPE_CHECKING:
from lektor.pagination import Pagination from lektor.pagination import Pagination
from lektor.builder import Artifact from lektor.builder import Artifact
@@ -40,60 +37,64 @@ class GroupBySource(VirtualSourceObject):
super().__init__(record) super().__init__(record)
self.key = slug self.key = slug
self.page_num = page_num self.page_num = page_num
self._expr_fields = {} # type: Dict[str, Expression]
self.__children = [] # type: List[str] self.__children = [] # type: List[str]
self.__group_map = [] # type: List[str] self.__group_map = [] # type: List[Any]
def append_child(self, child: 'Record', group: str) -> None: def append_child(self, child: 'Record', group: Any) -> None:
if child not in self.__children: if child not in self.__children:
self.__children.append(child.path) self.__children.append(child.path)
# TODO: rename group to value
# __group_map is later used to find most used group # __group_map is later used to find most used group
self.__group_map.append(group) self.__group_map.append(group)
def _update_attr(self, key: str, value: Any) -> None:
    ''' Store `value` as lazy Expression field or as plain attribute. '''
    if not isinstance(value, Expression):
        # plain value: forget a previously registered expression (if any)
        # and set a regular attribute
        self._expr_fields.pop(key, None)
        setattr(self, key, value)
        return
    # Expression: register it and remove a regular attribute of the same
    # name, which may not exist.
    self._expr_fields[key] = value
    try:
        delattr(self, key)
    except AttributeError:
        pass
# ------------------------- # -------------------------
# Evaluate Extra Fields # Evaluate Extra Fields
# ------------------------- # -------------------------
def finalize(self, config: 'Config', group: Optional[str] = None) \ def finalize(self, config: 'Config', group: Optional[Any] = None) \
-> 'GroupBySource': -> 'GroupBySource':
self.config = config self.config = config
# make a sorted children query # make a sorted children query
self._query = FixedRecordsQuery(self.pad, self.__children, self.alt) self._query = FixedRecordsQuery(self.pad, self.__children, self.alt)
self._query._order_by = config.order_by self._query._order_by = config.order_by
del self.__children
# set group name # set group name
self.group = group or most_used_key(self.__group_map) self.group = group or most_used_key(self.__group_map)
# cleanup temporary data
del self.__children
del self.__group_map del self.__group_map
# evaluate slug Expression # evaluate slug Expression
self.slug = None # type: Optional[str] self.slug = config.eval_slug(self.key, on=self)
if config.slug and '{key}' in config.slug:
self.slug = config.slug.replace('{key}', self.key)
else:
self.slug = self._eval(config.slug, field='slug')
assert self.slug != Ellipsis, 'invalid config: ' + config.slug
if self.slug and self.slug.endswith('/index.html'): if self.slug and self.slug.endswith('/index.html'):
self.slug = self.slug[:-10] self.slug = self.slug[:-10]
# extra fields
for attr, expr in config.fields.items():
setattr(self, attr, self._eval(expr, field='fields.' + attr))
return self
def _eval(self, value: Any, *, field: str) -> Any: if group: # exit early if initialized through resolver
''' Internal only: evaluates Lektor config file field expression. ''' return self
if not isinstance(value, str): # extra fields
return value for attr in config.fields:
pad = self.record.pad self._update_attr(attr, config.eval_field(attr, on=self))
alt = self.record.alt return self
try:
return Expression(pad.env, value).evaluate(pad, this=self, alt=alt)
except Exception as e:
report_config_error(self.config.key, field, value, e)
return Ellipsis
# ----------------------- # -----------------------
# Pagination handling # Pagination handling
# ----------------------- # -----------------------
@property
def supports_pagination(self) -> bool:
return self.config.pagination['enabled'] # type: ignore[no-any-return]
@cached_property @cached_property
def _pagination_config(self) -> 'PaginationConfig': def _pagination_config(self) -> 'PaginationConfig':
# Generate `PaginationConfig` once we need it # Generate `PaginationConfig` once we need it
@@ -128,25 +129,30 @@ class GroupBySource(VirtualSourceObject):
vpath += '/' + str(self.page_num) vpath += '/' + str(self.page_num)
return vpath return vpath
@property @cached_property
def url_path(self) -> str: def url_path(self) -> str: # type: ignore[override]
# Actual path to resource as seen by the browser ''' Actual path to resource as seen by the browser. '''
# check if slug is absolute URL
slug = self.slug
if slug and slug.startswith('/'):
parts = [self.pad.get_root(alt=self.alt).url_path]
else:
parts = [self.record.url_path] parts = [self.record.url_path]
# slug can be None!! # slug can be None!!
if not self.slug: if not slug:
return build_url(parts) return build_url(parts)
# if pagination enabled, append pagination.url_suffix to path # if pagination enabled, append pagination.url_suffix to path
if self.page_num and self.page_num > 1: if self.page_num and self.page_num > 1:
sffx = self._pagination_config.url_suffix sffx = self._pagination_config.url_suffix
if '.' in self.slug.split('/')[-1]: if '.' in slug.rsplit('/', 1)[-1]:
# default: ../slugpage2.html (use e.g.: url_suffix = .page.) # default: ../slugpage2.html (use e.g.: url_suffix = .page.)
parts.append(insert_before_ext( parts.append(insert_before_ext(
self.slug, sffx + str(self.page_num), '.')) slug, sffx + str(self.page_num), '.'))
else: else:
# default: ../slug/page/2/index.html # default: ../slug/page/2/index.html
parts += [self.slug, sffx, self.page_num] parts += [slug, sffx, self.page_num]
else: else:
parts.append(self.slug) parts.append(slug)
return build_url(parts) return build_url(parts)
def iter_source_filenames(self) -> Generator[str, None, None]: def iter_source_filenames(self) -> Generator[str, None, None]:
@@ -188,9 +194,23 @@ class GroupBySource(VirtualSourceObject):
return getattr(self, key[1:]) return getattr(self, key[1:])
return self.__missing__(key) return self.__missing__(key)
def __getattr__(self, key: str) -> Any:
''' Lazy evaluate custom user field expressions. '''
if key in self._expr_fields:
expr = self._expr_fields[key]
return expr.evaluate(self.pad, this=self, alt=self.alt)
raise AttributeError
def __lt__(self, other: 'GroupBySource') -> bool: def __lt__(self, other: 'GroupBySource') -> bool:
# Used for |sort filter ("group" is the provided original string) # Used for |sort filter ("group" is the provided original string)
return self.group.lower() < other.group.lower() if isinstance(self.group, (bool, int, float)) and \
isinstance(other.group, (bool, int, float)):
return self.group < other.group
if self.group is None:
return False
if other.group is None:
return True
return str(self.group).lower() < str(other.group).lower()
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
# Used for |unique filter # Used for |unique filter

View File

@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Dict, List, Tuple, Any, Union, NamedTuple from typing import TYPE_CHECKING, Dict, List, Any, Union, NamedTuple
from typing import Optional, Callable, Iterator, Generator from typing import Optional, Callable, Iterator, Generator
from .backref import VGroups from .backref import VGroups
from .model import ModelReader from .model import ModelReader
@@ -16,8 +16,8 @@ class GroupByCallbackArgs(NamedTuple):
GroupingCallback = Callable[[GroupByCallbackArgs], Union[ GroupingCallback = Callable[[GroupByCallbackArgs], Union[
Iterator[Union[str, Tuple[str, Any]]], Iterator[Any],
Generator[Union[str, Tuple[str, Any]], Optional[str], None], Generator[Optional[str], Any, None],
]] ]]
@@ -49,7 +49,8 @@ class Watcher:
assert callable(self.callback), 'No grouping callback provided.' assert callable(self.callback), 'No grouping callback provided.'
self._model_reader = ModelReader(pad.db, self.config.key, self.flatten) self._model_reader = ModelReader(pad.db, self.config.key, self.flatten)
self._root_record = {} # type: Dict[str, Record] self._root_record = {} # type: Dict[str, Record]
self._state = {} # type: Dict[str, Dict[str, GroupBySource]] self._state = {} # type: Dict[str, Dict[Optional[str], GroupBySource]]
self._rmmbr = [] # type: List[Record]
for alt in pad.config.iter_alternatives(): for alt in pad.config.iter_alternatives():
self._root_record[alt] = pad.get(self._root, alt=alt) self._root_record[alt] = pad.get(self._root, alt=alt)
self._state[alt] = {} self._state[alt] = {}
@@ -64,13 +65,15 @@ class Watcher:
Each record is guaranteed to be processed only once. Each record is guaranteed to be processed only once.
''' '''
for key, field in self._model_reader.read(record): for key, field in self._model_reader.read(record):
_gen = self.callback(GroupByCallbackArgs(record, key, field)) args = GroupByCallbackArgs(record, key, field)
_gen = self.callback(args)
try: try:
group = next(_gen) group = next(_gen)
while True: while True:
if not isinstance(group, str): if self.config.key_map_fn:
raise TypeError(f'Unsupported groupby yield: {group}') slug = self._persist_multiple(args, group)
slug = self._persist(record, key, group) else:
slug = self._persist(args, group)
# return slugified group key and continue iteration # return slugified group key and continue iteration
if isinstance(_gen, Generator) and not _gen.gi_yieldfrom: if isinstance(_gen, Generator) and not _gen.gi_yieldfrom:
group = _gen.send(slug) group = _gen.send(slug)
@@ -79,24 +82,57 @@ class Watcher:
except StopIteration: except StopIteration:
del _gen del _gen
def _persist(self, record: 'Record', key: 'FieldKeyPath', group: str) \ def _persist_multiple(self, args: 'GroupByCallbackArgs', obj: Any) \
-> str: -> Optional[str]:
# if custom key mapping function defined, use that first
res = self.config.eval_key_map_fn(on=args.record,
context={'X': obj, 'SRC': args})
if isinstance(res, (list, tuple)):
for k in res:
self._persist(args, k) # 1-to-n replacement
return None
return self._persist(args, res) # normal & null replacement
def _persist(self, args: 'GroupByCallbackArgs', obj: Any) \
-> Optional[str]:
''' Update internal state. Return slugified string. ''' ''' Update internal state. Return slugified string. '''
alt = record.alt if not isinstance(obj, (str, bool, int, float)) and obj is not None:
slug = self.config.slugify(group) raise ValueError(
'Unsupported groupby yield type for [{}]:'
' {} (expected str, got {})'.format(
self.config.key, obj, type(obj).__name__))
if obj is None:
# if obj is not set, test if config.replace_none_key is set
slug = self.config.replace_none_key
obj = slug
else:
# if obj is set, apply config.key_map (convert int -> str)
slug = self.config.slugify(str(obj)) or None
# if neither custom mapping succeeded, do not process further
if not slug or obj is None:
return slug
# update internal object storage
alt = args.record.alt
if slug not in self._state[alt]: if slug not in self._state[alt]:
src = GroupBySource(self._root_record[alt], slug) src = GroupBySource(self._root_record[alt], slug)
self._state[alt][slug] = src self._state[alt][slug] = src
else: else:
src = self._state[alt][slug] src = self._state[alt][slug]
src.append_child(record, group) src.append_child(args.record, obj) # obj is used as "group" string
# reverse reference # reverse reference
VGroups.of(record).add(key, src) VGroups.of(args.record).add(args.key, src)
return slug return slug
def iter_sources(self, root: 'Record') -> Iterator[GroupBySource]: def remember(self, record: 'Record') -> None:
self._rmmbr.append(record)
def iter_sources(self) -> Iterator[GroupBySource]:
''' Prepare and yield GroupBySource elements. ''' ''' Prepare and yield GroupBySource elements. '''
for x in self._rmmbr:
self.process(x)
del self._rmmbr
for vobj_list in self._state.values(): for vobj_list in self._state.values():
for vobj in vobj_list.values(): for vobj in vobj_list.values():
yield vobj.finalize(self.config) yield vobj.finalize(self.config)