typing + smaller bugfixes

This commit is contained in:
relikd
2022-04-09 03:45:48 +02:00
parent a25b62d934
commit d0c5072d27
15 changed files with 415 additions and 199 deletions

View File

@@ -1,7 +1,7 @@
'''
Collection of tools to streamline data format conversion.
'''
__version__ = '1.0.0'
__version__ = '1.0.1'
# import sys
# if __name__ != '__main__':

View File

@@ -1,9 +1,10 @@
#!/usr/bin/env python3
import os
from argparse import ArgumentParser, ArgumentTypeError, FileType
from argparse import ArgumentParser, ArgumentTypeError, FileType, Namespace
from typing import Any
def DirType(string):
def DirType(string: str) -> str:
if os.path.isdir(string):
return string
raise ArgumentTypeError(
@@ -11,20 +12,20 @@ def DirType(string):
class Cli(ArgumentParser):
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
def arg(self, *args, **kwargs):
def arg(self, *args: Any, **kwargs: Any) -> None:
self.add_argument(*args, **kwargs)
def arg_bool(self, *args, **kwargs):
def arg_bool(self, *args: Any, **kwargs: Any) -> None:
self.add_argument(*args, **kwargs, action='store_true')
def arg_dir(self, *args, **kwargs):
def arg_dir(self, *args: Any, **kwargs: Any) -> None:
self.add_argument(*args, **kwargs, type=DirType)
def arg_file(self, *args, mode='r', **kwargs):
def arg_file(self, *args: Any, mode: str = 'r', **kwargs: Any) -> None:
self.add_argument(*args, **kwargs, type=FileType(mode))
def parse(self):
def parse(self) -> Namespace:
return self.parse_args()

View File

@@ -2,65 +2,101 @@
from sys import stderr
from threading import Timer
from datetime import datetime as date
from typing import List, Any, Optional, Iterable, Callable
CronCallback = Callable[[Any], None]
class RepeatTimer(Timer):
def run(self):
''' Repeatedly call function with defined time interval. '''
def run(self) -> None:
while not self.finished.wait(self.interval):
self.function(*self.args, **self.kwargs)
class Cron:
''' Call one or more functions with fixed time interval. '''
class Job:
def __init__(self, interval, callback, object=None):
''' Holds information about the interval and callback method. '''
def __init__(
self,
interval: int,
callback: CronCallback,
object: Any = None
):
self.interval = interval
self.callback = callback
self.object = object
def run(self, ts=0):
def run(self, ts: int = 0) -> None:
if self.interval > 0 and ts % self.interval == 0:
self.callback(self.object)
@staticmethod
def simple(interval: int, callback, arg=None, *, sleep=range(1, 8)):
def simple(
interval: int,
callback: CronCallback,
arg: Any = None,
*, sleep: Iterable[int] = range(1, 8)
) -> 'Cron':
''' Convenient initializer. Add job and start timer. '''
cron = Cron(sleep=sleep)
cron.add_job(interval, callback, arg)
cron.start()
return cron
def __init__(self, *, sleep=range(1, 8)):
def __init__(self, *, sleep: Iterable[int] = range(1, 8)):
self.sleep = sleep
self._timer = None
self._timer = None # type: Optional[RepeatTimer]
self._last_t = -1
self.clear()
def clear(self):
self.jobs = []
def clear(self) -> None:
''' Remove all previously added jobs. '''
self.jobs = [] # type: List[Cron.Job]
def add_job(self, interval: int, callback, arg=None):
def add_job(self, interval: int, callback: CronCallback, arg: Any = None) \
-> Job:
''' Create and queue a new job. '''
job = Cron.Job(interval, callback, arg)
self.push(job)
return job
def push(self, job):
def push(self, job: Job) -> None:
''' Queue an existing job. '''
assert isinstance(job, Cron.Job), type(job)
self.jobs.append(job)
def pop(self, key):
def pop(self, key: str) -> Job:
''' Return and remove job with known key. '''
return self.jobs.pop(self.jobs.index(self.get(key)))
def get(self, key):
for x in self.jobs:
obj = x.object
if not obj:
def get(self, key: str) -> Job:
''' Find job with known key. job.object must be list[0] or str. '''
for job in self.jobs:
x = job.object
if not x:
continue
if (isinstance(obj, list) and obj[0] == key) or obj == key:
return x
if (isinstance(x, (list, tuple)) and x[0] == key) or x == key:
return job
raise KeyError('Key not found: ' + str(key))
# CSV import / export
def load_csv(self, fname: str, callback, *, cols: []):
def load_csv(
self,
fname: str,
callback: CronCallback,
*, cols: List[Callable[[str], Any]]
) -> int:
'''
Load comma separated CSV file. Return number of loaded jobs.
First column must be time interval.
`cols` is a list of value transformers, e.g., int, str, ...
'''
self.clear()
try:
with open(fname) as fp:
@@ -71,13 +107,14 @@ class Cron:
obj = [fn(o) if o else None for o, fn in zip(obj, cols)]
if len(obj) < len(cols):
obj += [None] * (len(cols) - len(obj))
self.add_job(int(time), callback, obj)
self.add_job(int(time or 0), callback, obj)
except FileNotFoundError:
print('File "{}" not found. No jobs loaded.'.format(fname),
file=stderr)
return len(self.jobs)
def save_csv(self, fname: str, *, cols: [str]):
def save_csv(self, fname: str, *, cols: List[str]) -> None:
''' Persist in-memory jobs to CSV file. `cols` are column headers. '''
with open(fname, 'w') as fp:
fp.write(' , '.join(['# interval'] + cols) + '\n')
for job in self.jobs:
@@ -93,24 +130,28 @@ class Cron:
# Handle repeat timer
def start(self):
def start(self) -> None:
''' Start cron timer interval. Check every 15 sec. '''
if not self._timer:
self._timer = RepeatTimer(15, self._callback)
self._timer.start() # cancel()
def stop(self):
def stop(self) -> None:
''' Stop or pause timer. '''
if self._timer:
if self._timer.is_alive():
self._timer.cancel()
self._timer = None
def fire(self):
def fire(self) -> None:
''' Run all jobs immediately. '''
now = date.now()
self._last_t = now.day * 1440 + now.hour * 60 + now.minute
for job in self.jobs:
job.run()
def _callback(self):
def _callback(self) -> None:
''' [internal] check if interval matches current time and execute. '''
now = date.now()
if now.hour in self.sleep:
return
@@ -122,6 +163,6 @@ class Cron:
for job in self.jobs:
job.run(ts)
def __str__(self):
def __str__(self) -> str:
return '\n'.join('@{}m {}'.format(job.interval, job.object)
for job in self.jobs)

View File

@@ -4,46 +4,53 @@ import json
from sys import stderr
from hashlib import md5
from urllib.error import HTTPError, URLError
from urllib.parse import urlparse
from urllib.parse import urlparse, ParseResult
from urllib.request import urlretrieve, urlopen, Request
from typing import List, Dict, Optional, Any, TextIO
from datetime import datetime # typing
from http.client import HTTPResponse # typing
from .helper import FileTime
import ssl
# somehow macOS default behavior for SSL verification is broken
ssl._create_default_https_context = ssl._create_unverified_context
def _read_modified_header(fname: str): # dict or None
if not os.path.isfile(fname):
return None
def _read_modified_header(fname: str) -> Dict[str, str]:
''' Extract Etag and Last-Modified headers, rename for sending. '''
res = {}
with open(fname) as fp:
head = dict(x.strip().split(': ', 1) for x in fp.readlines())
etag = head.get('Etag')
if etag:
res['If-None-Match'] = etag
lastmod = head.get('Last-Modified')
if lastmod:
res['If-Modified-Since'] = lastmod.replace('-gzip', '')
return res or None
if os.path.isfile(fname):
with open(fname) as fp:
for line in fp.readlines():
key, val = line.strip().split(': ', 1)
if key == 'Etag' and val:
res['If-None-Match'] = val
elif key == 'Last-Modified' and val:
res['If-Modified-Since'] = val.replace('-gzip', '')
return res
class Curl:
''' Rename Curl.CACHE_DIR to move the cache somewhere else. '''
CACHE_DIR = 'cache'
@staticmethod
def valid_url(url):
def valid_url(url: str) -> Optional[ParseResult]:
''' If valid, return urlparse() result. '''
url = url.strip().replace(' ', '+')
x = urlparse(url)
return x if x.scheme and x.netloc else None
@staticmethod
def url_hash(url) -> str:
def url_hash(url: str) -> str:
''' Unique url-hash used for filename / storage. '''
x = Curl.valid_url(url)
return '{}-{}'.format(x.hostname if x else 'ERR',
md5(url.encode()).hexdigest())
@staticmethod
def open(url: str, *, headers={}): # url-open-pointer or None
def open(url: str, *, headers: Optional[Dict[str, str]] = None) \
-> Optional[HTTPResponse]:
''' Open a network connection, return urlopen() result or None. '''
try:
head = {'User-Agent': 'Mozilla/5.0'}
if headers:
@@ -57,7 +64,7 @@ class Curl:
return None
@staticmethod
def get(url: str, *, cache_only=False): # file-pointer
def get(url: str, *, cache_only: bool = False) -> Optional[TextIO]:
'''
Returns an already open file pointer.
You are responsible for closing the file.
@@ -74,17 +81,19 @@ class Curl:
if conn:
with open(fname_head, 'w') as fp:
fp.write(str(conn.info()).strip())
with open(fname, 'wb') as fp:
with open(fname, 'wb') as fpb:
while True:
data = conn.read(8192) # 1024 Bytes
if not data:
break
fp.write(data)
if os.path.isfile(fname):
return open(fname)
fpb.write(data)
return open(fname) if os.path.isfile(fname) else None
@staticmethod
def json(url: str, fallback=None, *, cache_only=False) -> object:
def json(url: str, fallback: Any = None, *, cache_only: bool = False) \
-> Any:
''' Open network connection and download + parse json result. '''
conn = Curl.get(url, cache_only=cache_only)
if not conn:
return fallback
@@ -92,11 +101,15 @@ class Curl:
return json.load(fp)
@staticmethod
def file(url: str, dest_path: str, *, raise_except=False) -> bool:
tmp_file = dest_path + '.inprogress'
def file(url: str, dest_file: str, *, raise_except: bool = False) -> bool:
'''
Download raw data to file. Creates an intermediate ".inprogress" file.
If raise_except = False, silently ignore errors (default).
'''
tmp_file = dest_file + '.inprogress'
try:
urlretrieve(url, tmp_file)
os.rename(tmp_file, dest_path) # atomic download, no broken files
os.rename(tmp_file, dest_file) # atomic download, no broken files
return True
except HTTPError as e:
# print('ERROR: Load URL "{}" -- {}'.format(url, e), file=stderr)
@@ -105,8 +118,23 @@ class Curl:
return False
@staticmethod
def once(dest_dir, fname, urllist, date=None, *,
override=False, dry_run=False, verbose=False, intro=''):
def once(
dest_dir: str,
fname: str,
urllist: List[str],
date: Optional[datetime] = None,
*, override: bool = False,
dry_run: bool = False,
verbose: bool = False,
intro: Optional[str] = None
) -> bool:
'''
Download and store a list of raw files. If local file exists, ignore.
`fname` should be the filename without extension. Extension is added
based on the extension in the `urllist` (per file).
If `date` is set, change last modified date of downloaded file.
Print `intro` before download (if any loaded or if `override`).
'''
did_update = False
for url_str in urllist:
parts = Curl.valid_url(url_str)

View File

@@ -1,10 +1,15 @@
#!/usr/bin/env python3
import xml.etree.ElementTree as ET
from typing import List, Dict, Any, Optional, Union, TextIO, BinaryIO
from .helper import StrFormat
def Feed2List(fp, *, keys=[]):
def parse_xml_without_namespace(file):
def Feed2List(
fp: Optional[Union[TextIO, BinaryIO]],
*, keys: Optional[List[str]] = None
) -> List[Dict[str, Any]]:
''' Parse RSS or Atom feed and return list of entries. '''
def parse_without_namespace(file: Union[TextIO, BinaryIO]) -> ET.Element:
ns = {}
xml_iter = ET.iterparse(file, ('start-ns', 'start'))
for event, elem in xml_iter:
@@ -15,8 +20,10 @@ def Feed2List(fp, *, keys=[]):
elem.tag = ''.join(ns[x] for x in tag[:-1]) + tag[-1]
return xml_iter.root
if not fp:
return []
# detect feed format (RSS / Atom)
root = parse_xml_without_namespace(fp)
root = parse_without_namespace(fp)
fp.close()
if root.tag == 'rss': # RSS
selector = 'channel/item'
@@ -30,7 +37,7 @@ def Feed2List(fp, *, keys=[]):
# parse XML
result = []
for item in root.findall(selector):
obj = {}
obj = {} # type: Dict[str, Any]
for child in item:
tag = child.tag
# Filter keys that are clearly not wanted by user
@@ -48,9 +55,9 @@ def Feed2List(fp, *, keys=[]):
value = attr
# Auto-create list type if duplicate keys are used
try:
obj[tag]
if not isinstance(obj[tag], list):
obj[tag] = [obj[tag]]
prev_val = obj[tag]
if not isinstance(prev_val, list):
obj[tag] = [prev_val]
obj[tag].append(value)
except KeyError:
obj[tag] = value

View File

@@ -7,26 +7,31 @@ from html import unescape
from datetime import datetime
import unicodedata # normalize
from string import ascii_letters, digits
from typing import Optional, Callable, Union
class Log:
@staticmethod
def error(e):
def error(e: str) -> None:
''' Log error message (incl. current timestamp) '''
print('{} [ERROR] {}'.format(datetime.now(), e), file=stderr)
@staticmethod
def info(m):
def info(m: str) -> None:
''' Log info message (incl. current timestamp) '''
print('{} {}'.format(datetime.now(), m))
class FileTime:
@staticmethod
def set(fname, date):
def set(fname: str, date: datetime) -> None:
''' Set file modification time. '''
modTime = time.mktime(date.timetuple())
os.utime(fname, (modTime, modTime))
@staticmethod
def get(fname, *, absolute=False):
def get(fname: str, *, absolute: bool = False) -> float:
''' Get file modification time. '''
x = os.path.getmtime(fname)
return x if absolute else time.time() - x
@@ -40,7 +45,11 @@ class StrFormat:
re_crlf = re.compile(r'[\n\r]{2,}')
@staticmethod
def strip_html(text):
def strip_html(text: str) -> str:
'''
Remove all html tags and replace with readable alternative.
Also, strips unnecessary newlines, nbsp, br, etc.
'''
text = StrFormat.re_img.sub(r'[IMG: \2, \1\3]', text)
text = StrFormat.re_href.sub(r'\2 (\1)', text)
text = StrFormat.re_br.sub('\n', text)
@@ -49,7 +58,8 @@ class StrFormat:
return unescape(text).replace(' ', ' ').strip()
@staticmethod
def to_date(text):
def to_date(text: str) -> datetime:
''' Try parse string as date, currently RSS + Atom format. '''
for date_format in (
'%a, %d %b %Y %H:%M:%S %z', # RSS
'%Y-%m-%dT%H:%M:%S%z', # Atom
@@ -66,18 +76,31 @@ class StrFormat:
fnameChars = set('-_.,() {}{}'.format(ascii_letters, digits))
@staticmethod
def safe_filename(text):
def safe_filename(text: str) -> str:
''' Replace umlauts and unsafe characters (filesystem safe). '''
text = unicodedata.normalize('NFKD', text) # makes 2-bytes of umlauts
text = text.replace('̈', 'e') # replace umlauts e.g., Ä -> Ae
text = text.encode('ASCII', 'ignore')
return ''.join(chr(c) for c in text if chr(c) in StrFormat.fnameChars)
data = text.encode('ASCII', 'ignore')
return ''.join(chr(c) for c in data if chr(c) in StrFormat.fnameChars)
class FileWrite:
@staticmethod
def once(dest_dir, fname, date=None, *,
override=False, dry_run=False, verbose=False, intro=''):
def _decorator(func):
def once(
dest_dir: str,
fname: str,
date: Optional[datetime] = None,
*, override: bool = False,
dry_run: bool = False,
verbose: bool = False,
intro: Union[str, bool, None] = None
) -> Callable[[Callable[[], Optional[str]]], None]:
'''
Write file to disk but only if it does not exist already.
The callback method is only called if the file does not exist yet.
Use as decorator to a function: @FileWrite.once(...)
'''
def _decorator(func: Callable[[], Optional[str]]) -> None:
path = os.path.join(dest_dir, fname)
if os.path.isfile(path) and not override:
return
@@ -85,7 +108,7 @@ class FileWrite:
if not content:
return
if verbose:
if intro and not isinstance(intro, bool):
if intro and intro is not True:
print(intro)
print(' >', path)
if dry_run:

View File

@@ -3,24 +3,29 @@ import re
import json
from sys import stderr
from argparse import ArgumentParser, FileType
from typing import List, Tuple, Dict, Optional, Union, Callable
from typing import TextIO, BinaryIO, Iterator, KeysView
from html.parser import HTMLParser
XMLAttrs = List[Tuple[str, Optional[str]]]
class CSSSelector:
''' Limited support, match single tag with classes: div.class.other '''
def __init__(self, selector):
def __init__(self, selector: str) -> None:
if any(x in ' >+' for x in selector):
raise NotImplementedError(
'No support for nested tags. "{}"'.format(selector))
self.tag, *self.cls = selector.split('.')
def matches(self, tag, attrs):
def matches(self, tag: str, attrs: XMLAttrs) -> bool:
''' Test if tag and attrs match the target selector. '''
if self.tag and tag != self.tag:
return False
if self.cls:
for k, val in attrs:
if k == 'class':
if k == 'class' and val:
classes = val.split()
return all(x in classes for x in self.cls)
return False
@@ -34,27 +39,31 @@ class HTML2List(HTMLParser):
If not set, return a list of strings instead.
'''
def __init__(self, select, callback=None):
def __init__(
self,
select: str,
callback: Optional[Callable[[str], str]] = None
) -> None:
super().__init__()
self._filter = CSSSelector(select)
self._data = '' # temporary data built-up
self._elem = [] # tag stack
self._elem = [] # type: List[str] # tag stack
self._tgt = 0 # remember matching level for filter
self._result = [] # empty if callback
self._result = [] # type: List[str] # empty if callback
self._callback = callback or self._result.append
def parse(self, source):
def parse(self, source: Optional[Union[TextIO, BinaryIO]]) -> List[str]:
'''
:source: A file-pointer or web-source with read() attribute.
Warning: return value empty if callback is set!
'''
def rb2str(data, fp, limit=256):
def rb2str(data: bytes, fp: BinaryIO, limit: int = 256) -> str:
try:
return data.decode('utf-8')
except UnicodeDecodeError:
extra = fp.read(limit)
if not extra:
return data
return data # type: ignore[return-value]
return rb2str(data + extra, fp, limit)
if not source:
@@ -63,37 +72,41 @@ class HTML2List(HTMLParser):
while True:
try:
data = source.read(65536) # 64k
if not data:
break
except Exception as e:
print('ERROR: {}'.format(e), file=stderr)
data = None
if not data:
break
if isinstance(data, bytes):
data = rb2str(data, source)
data = rb2str(data, source) # type: ignore[arg-type]
self.feed(data)
source.close()
self.close()
return self._result
def handle_starttag(self, tag, attrs):
def handle_starttag(self, tag: str, attrs: XMLAttrs) -> None:
''' [internal] HTMLParser callback '''
self._elem.append(tag)
if self._filter.matches(tag, attrs):
if self._tgt > 0:
raise RuntimeError('No nested tags! Adjust your filter.')
self._tgt = len(self._elem) - 1
if self._tgt > 0:
self._data += self.get_starttag_text()
self._data += self.get_starttag_text() or ''
def handle_startendtag(self, tag, attrs):
def handle_startendtag(self, tag: str, attrs: XMLAttrs) -> None:
''' [internal] HTMLParser callback '''
self._elem.append(tag)
if self._tgt > 0:
self._data += self.get_starttag_text()
self._data += self.get_starttag_text() or ''
def handle_data(self, data):
def handle_data(self, data: str) -> None:
''' [internal] HTMLParser callback '''
if self._tgt > 0:
self._data += data
def handle_endtag(self, tag):
def handle_endtag(self, tag: str) -> None:
''' [internal] HTMLParser callback '''
if self._tgt > 0:
self._data += '</{}>'.format(tag)
# drop any non-closed tags
@@ -117,43 +130,63 @@ class Grep:
'''
re_whitespace = re.compile(r'\s+') # will also replace newline with space
def __init__(self, regex, *, cleanup=True):
def __init__(self, regex: str, *, cleanup: bool = True) -> None:
self.cleanup = cleanup
self._rgx = re.compile(regex)
def find(self, text):
res = self._rgx.search(text)
if not res:
def find(self, text: str) -> Optional[str]:
''' Perform regex search to find desired snippet. '''
grp = self._rgx.search(text)
if not grp:
return None
res = res.groups()[0]
res = grp.groups()[0]
if self.cleanup:
return self.re_whitespace.sub(' ', res.strip())
return res
class MatchGroup:
class MatchGroup(dict):
''' Use {#tagname#} to replace values with regex value. '''
re_tag = re.compile(r'{#(.*?)#}')
def __init__(self, grepDict={}):
self._regex = {}
for k, v in grepDict.items():
def __init__(
self,
grepDict: Optional[Dict[str, Union[str, Grep]]] = None
) -> None:
self._regex = {} # type: Dict[str, Grep]
for k, v in (grepDict or {}).items():
self.add(k, v)
self.set_html('')
def add(self, tagname, regex, *, cleanup=True):
def add(
self,
tagname: str,
regex: Union[str, Grep],
*, cleanup: bool = True
) -> None:
''' Add a single search pattern to the internal table. '''
self._regex[tagname] = \
regex if isinstance(regex, Grep) else Grep(regex, cleanup=cleanup)
def set_html(self, html):
def set_html(self, html: str) -> 'MatchGroup':
''' Reuse existing MatchGroup but set new search html. '''
self._html = html
self._res = {}
self._res = {} # type: Dict[str, Optional[str]]
return self
def keys(self):
def keys(self) -> KeysView[str]:
''' Get all search keys. '''
return self._regex.keys()
def __getitem__(self, key):
def __iter__(self) -> Iterator[str]:
''' Iter is required for a dict subclass to support **unwrap. '''
return iter(self.keys())
def __getitem__(self, key: str) -> Optional[str]:
'''
Conditional getter. Regex search is only performed on access.
Once the search was performed, result is cached until `set_html()`.
'''
try:
return self._res[key]
except KeyError:
@@ -161,19 +194,21 @@ class MatchGroup:
self._res[key] = val
return val
def __str__(self):
return '\n'.join(
'{}: {}'.format(k, self._res.get(k, '<?>')) for k in self._regex)
def __str__(self) -> str:
return '\n'.join('{}: {}'.format(k, self._res.get(k, '<?>'))
for k in self._regex)
def to_dict(self):
def to_dict(self) -> Dict[str, Optional[str]]:
''' Force unwrap all keys and perform regex. '''
return {k: self[k] for k in self._regex}
def use_template(self, template):
def use_template(self, template: str) -> str:
''' Use {#tagname#} to replace values with regex value. '''
return self.re_tag.sub(lambda x: self[x.groups()[0]], template)
return self.re_tag.sub(lambda x: self[x.groups()[0]] or '', template)
def _cli():
def _cli() -> None:
''' CLI entry point. '''
parser = ArgumentParser()
parser.add_argument('FILE', type=FileType('r'), help='Input html file')
parser.add_argument('selector', help='CSS selector. E.g., article.entry')

View File

@@ -9,10 +9,13 @@ Usage: Load existing `OnceDB()` and `put(cohort, uid, obj)` new entries.
Once in a while call `cleanup()` to remove old entries.
'''
import sqlite3
from typing import Tuple, Any, Callable, Iterator
DBEntry = Tuple[int, str, str, Any]
class OnceDB:
def __init__(self, db_path):
def __init__(self, db_path: str) -> None:
self._db = sqlite3.connect(db_path)
self._db.execute('''
CREATE TABLE IF NOT EXISTS queue(
@@ -24,10 +27,10 @@ class OnceDB:
);
''')
def __del__(self):
def __del__(self) -> None:
self._db.close()
def cleanup(self, limit=200):
def cleanup(self, limit: int = 200) -> None:
''' Delete oldest (cohort) entries if more than limit exist. '''
self._db.execute('''
WITH _tmp AS (
@@ -41,7 +44,7 @@ class OnceDB:
''', (limit,))
self._db.commit()
def put(self, cohort, uid, obj):
def put(self, cohort: str, uid: str, obj: str) -> bool:
''' Silently ignore if a duplicate (cohort, uid) is added. '''
try:
self._db.execute('''
@@ -53,7 +56,8 @@ class OnceDB:
# entry (cohort, uid) already exists
return False
def contains(self, cohort, uid):
def contains(self, cohort: str, uid: str) -> bool:
''' Test if cohort + uid pair exists in database. '''
cur = self._db.cursor()
cur.execute('''
SELECT 1 FROM queue WHERE cohort IS ? AND uid is ? LIMIT 1;
@@ -62,7 +66,7 @@ class OnceDB:
cur.close()
return flag
def mark_done(self, rowid):
def mark_done(self, rowid: int) -> None:
''' Mark (ROWID) as done. Entry remains in cache until cleanup(). '''
if not isinstance(rowid, int):
raise AttributeError('Not of type ROWID: {}'.format(rowid))
@@ -70,12 +74,16 @@ class OnceDB:
(rowid, ))
self._db.commit()
def mark_all_done(self):
def mark_all_done(self) -> None:
''' Mark all entries done. Entry remains in cache until cleanup(). '''
self._db.execute('UPDATE queue SET obj = NULL;')
self._db.commit()
def foreach(self, callback, *, reverse=False):
def foreach(
self,
callback: Callable[[str, str, Any], bool],
*, reverse: bool = False
) -> bool:
'''
Exec for all until callback evaluates to false (or end of list).
Automatically marks entries as done (only on success).
@@ -87,16 +95,19 @@ class OnceDB:
return False
return True
def __iter__(self, *, reverse=False):
def __iter__(self) -> Iterator[DBEntry]:
return self.iter()
def __reversed__(self) -> Iterator[DBEntry]:
return self.iter(desc=True)
def iter(self, *, desc: bool = False) -> Iterator[DBEntry]:
''' Perform query on all un-marked / not-done entries. '''
cur = self._db.cursor()
cur.execute('''
SELECT ROWID, cohort, uid, obj FROM queue
WHERE obj IS NOT NULL
ORDER BY ROWID {};
'''.format('DESC' if reverse else 'ASC'))
'''.format('DESC' if desc else 'ASC'))
yield from cur.fetchall()
cur.close()
def __reversed__(self):
return self.__iter__(reverse=True)

View File

@@ -2,31 +2,37 @@
import telebot # pip3 install pytelegrambotapi
from threading import Thread
from time import sleep
from typing import List, Optional, Any, Union, Iterable, Callable
from telebot.types import Message, Chat # typing
from .helper import Log
class Kill(Exception):
''' Used to intentionally kill the bot. '''
pass
class TGClient(telebot.TeleBot):
@staticmethod
def listen_chat_info(api_key, user):
bot = TGClient(api_key, polling=True, allowedUsers=[user])
'''
Telegram client. Wrapper around telebot.TeleBot.
If `polling` is False, you can run the bot for a single send_message.
If `allowedUsers` is None, all users are allowed.
'''
@bot.message_handler(commands=['start'])
def handle_start(message):
bot.log_chat_info(message.chat)
raise Kill()
return bot
def __init__(self, apiKey, *, polling, allowedUsers=[], **kwargs):
def __init__(
self,
apiKey: str,
*, polling: bool,
allowedUsers: Optional[List[str]] = None,
**kwargs: Any
) -> None:
''' If `allowedUsers` is None, all users are allowed. '''
super().__init__(apiKey, **kwargs)
self.users = allowedUsers
self.onKillCallback = None
self.onKillCallback = None # type: Optional[Callable[[], None]]
if polling:
def _fn():
def _fn() -> None:
try:
Log.info('Ready')
self.polling(skip_pending=True) # none_stop=True
@@ -36,7 +42,7 @@ class TGClient(telebot.TeleBot):
self.onKillCallback()
return
except Exception as e:
Log.error(e)
Log.error(repr(e))
Log.info('Auto-restart in 15 sec ...')
sleep(15)
_fn()
@@ -44,45 +50,68 @@ class TGClient(telebot.TeleBot):
Thread(target=_fn, name='Polling').start()
@self.message_handler(commands=['?'])
def _healthcheck(message):
def _healthcheck(message: Message) -> None:
if self.allowed(message):
self.reply_to(message, 'yes')
@self.message_handler(commands=['kill'])
def _kill(message):
def _kill(message: Message) -> None:
if self.allowed(message):
self.reply_to(message, 'bye bye')
raise Kill()
def set_on_kill(self, callback):
def set_on_kill(self, callback: Optional[Callable[[], None]]) -> None:
''' Callback is executed when a Kill exception is raised. '''
self.onKillCallback = callback
@staticmethod
def listen_chat_info(api_key: str, user: str) -> 'TGClient':
''' Wait for a single /start command, print chat-id, then quit. '''
bot = TGClient(api_key, polling=True, allowedUsers=[user])
@bot.message_handler(commands=['start'])
def handle_start(message: Message) -> None:
bot.log_chat_info(message.chat)
raise Kill()
return bot
# Helper methods
def log_chat_info(self, chat):
def log_chat_info(self, chat: Chat) -> None:
''' Print current chat details (chat-id, title, etc.) to console. '''
Log.info('[INFO] chat-id: {} ({}, title: "{}")'.format(
chat.id, chat.type, chat.title or ''))
def allowed(self, src_msg):
def allowed(self, src_msg: Message) -> bool:
''' Return true if message is sent to a previously allowed user. '''
return not self.users or src_msg.from_user.username in self.users
def send(self, chat_id, msg, **kwargs):
def send(self, chat_id: int, msg: str, **kwargs: Any) -> Optional[Message]:
''' Send a message to chat. '''
try:
return self.send_message(chat_id, msg, **kwargs)
except Exception as e:
Log.error(e)
Log.error(repr(e))
sleep(45)
return None
def send_buttons(self, chat_id, msg, options):
def send_buttons(
self,
chat_id: int,
msg: str,
options: Iterable[Union[str, int, float]]
) -> Message:
''' Send tiling keyboard with predefined options to user. '''
markup = telebot.types.ReplyKeyboardMarkup(one_time_keyboard=True)
markup.add(*(telebot.types.KeyboardButton(x) for x in options))
markup.add(*(telebot.types.KeyboardButton(str(x)) for x in options))
return self.send_message(chat_id, msg, reply_markup=markup)
def send_abort_keyboard(self, src_msg, reply_msg):
def send_abort_keyboard(self, src_msg: Message, reply_msg: str) -> Message:
''' Cancel previously sent keyboards. '''
return self.reply_to(src_msg, reply_msg,
reply_markup=telebot.types.ReplyKeyboardRemove())
def send_force_reply(self, chat_id, msg):
def send_force_reply(self, chat_id: int, msg: str) -> Message:
''' Send a message which is automatically set to reply_to. '''
return self.send_message(chat_id, msg,
reply_markup=telebot.types.ForceReply())

View File

@@ -1,6 +1,8 @@
#!/usr/bin/env python3
import os
from sys import stderr
from typing import Dict, Any, Optional, TextIO
from datetime import datetime # typing
from botlib.cli import Cli
from botlib.curl import Curl
@@ -8,7 +10,8 @@ from botlib.feed2list import Feed2List
from botlib.helper import StrFormat, FileWrite
def main():
def main() -> None:
''' CLI entry. '''
cli = Cli()
cli.arg_dir('dest_dir', help='Download all entries here')
cli.arg('source', help='RSS file or web-url')
@@ -25,10 +28,16 @@ def main():
print('ERROR: ' + str(e), file=stderr)
def process(source, dest_dir, *, by_year=False, dry_run=False):
def process(
source: str, # local file path or remote url
dest_dir: str,
*, by_year: bool = False,
dry_run: bool = False
) -> bool:
''' Parse a full podcast file / source. '''
# open source
if os.path.isfile(source):
fp = open(source) # closed in Feed2List
fp = open(source) # type: Optional[TextIO] # closed in Feed2List
elif Curl.valid_url(source):
fp = Curl.get(source) # closed in Feed2List
else:
@@ -41,7 +50,7 @@ def process(source, dest_dir, *, by_year=False, dry_run=False):
'pubDate', 'media:content', # image
# 'itunes:image', 'itunes:duration', 'itunes:summary'
])):
date = entry.get('pubDate') # try RSS only
date = entry['pubDate'] # try RSS only # type: datetime
if by_year:
dest = os.path.join(dest_dir, str(date.year))
if not dry_run and not os.path.exists(dest):
@@ -50,7 +59,13 @@ def process(source, dest_dir, *, by_year=False, dry_run=False):
return True
def process_entry(entry, date, dest_dir, *, dry_run=False):
def process_entry(
entry: Dict[str, Any],
date: datetime,
dest_dir: str,
*, dry_run: bool = False
) -> None:
''' Parse a single podcast media entry. '''
title = entry['title']
# <enclosure url="*.mp3" length="47216000" type="audio/mpeg"/>
audio_url = entry.get('enclosure', {}).get('url')
@@ -78,10 +93,11 @@ def process_entry(entry, date, dest_dir, *, dry_run=False):
@FileWrite.once(dest_dir, fname + '.txt', date, override=False,
dry_run=dry_run, verbose=True, intro=flag or intro)
def _description():
desc = title + '\n' + '=' * len(title)
desc += '\n\n' + StrFormat.strip_html(entry.get('description', ''))
return desc + '\n\n\n' + entry.get('link', '') + '\n'
def _description() -> str:
return '{}\n{}\n\n{}\n\n\n{}\n'.format(
title, '=' * len(title),
StrFormat.strip_html(entry.get('description', '')),
entry.get('link', ''))
if __name__ == '__main__':

View File

@@ -1,6 +1,7 @@
#!/usr/bin/env python3
import os
from sys import stderr
from typing import Dict, Tuple, Optional, Any
from botlib.cli import Cli
from botlib.curl import Curl, URLError
@@ -15,7 +16,8 @@ db_slugs = OnceDB('radiolab_slugs.sqlite')
os.environ['TZ'] = 'America/New_York'
def main():
def main() -> None:
''' CLI entry. '''
cli = Cli()
cli.arg_dir('dest_dir', help='Download all episodes to dest_dir/year/')
cli.arg_bool('--dry-run', help='Do not download, just parse')
@@ -36,9 +38,17 @@ def main():
print('\nDone.\n\nNow check MP3 tags (consistency).')
def processEpisodeList(basedir, title, query, index=1, *, dry_run=False):
def processEpisodeList(
basedir: str,
title: str,
query: str,
index: int = 1,
*, dry_run: bool = False
) -> None:
''' Parse full podcast category. '''
print('\nProcessing: {}'.format(title), end='')
dat = Curl.json('{}/channel/shows/{}/{}?limit=9'.format(API, query, index))
url = '{}/channel/shows/{}/{}?limit=9'.format(API, query, index)
dat = Curl.json(url) # type: Dict[str, Any]
total = dat['data']['attributes']['total-pages']
print(' ({}/{})'.format(index, total))
anything_new = False
@@ -49,7 +59,12 @@ def processEpisodeList(basedir, title, query, index=1, *, dry_run=False):
processEpisodeList(basedir, title, query, index + 1, dry_run=dry_run)
def processEpisode(obj, basedir, *, dry_run=False):
def processEpisode(
obj: Dict[str, Any],
basedir: str,
*, dry_run: bool = False
) -> bool:
''' Parse a single podcast episode. '''
uid = obj['cms-pk']
if db_ids.contains(COHORT, uid):
return False # Already exists
@@ -86,18 +101,18 @@ def processEpisode(obj, basedir, *, dry_run=False):
@FileWrite.once(dest_dir, fname + '.txt', date, override=False,
dry_run=dry_run, verbose=True, intro=flag or intro)
def write_description():
def write_description() -> str:
nonlocal flag
flag = True
desc = title + '\n' + '=' * len(title)
desc += '\n\n' + StrFormat.strip_html(obj['body'])
desc = '{}\n{}\n\n{}'.format(
title, '=' * len(title), StrFormat.strip_html(obj['body']))
if img_desc:
desc += '\n\n' + img_desc
return desc + '\n\n\n' + obj['url'].strip() + '\n' # link to article
return '{}\n\n\n{}\n'.format(desc, obj['url'].strip()) # article link
@FileWrite.once(dest_dir, fname + '.transcript.txt', date, override=False,
dry_run=dry_run, verbose=True, intro=flag or intro)
def write_transcript():
def write_transcript() -> Optional[str]:
nonlocal flag
flag = True
data = StrFormat.strip_html(obj['transcript'])
@@ -111,7 +126,8 @@ def processEpisode(obj, basedir, *, dry_run=False):
return flag # potentially need to query the next page too
def get_img_desc(obj):
def get_img_desc(obj: Dict[str, Any]) -> Tuple[Optional[str], Optional[str]]:
''' Extract image description. '''
if not obj:
return (None, None)
url = (obj['url'] or '').strip()
@@ -135,7 +151,8 @@ def get_img_desc(obj):
# -> inurl:radiolab/episodes site:wnycstudios.org
# Then regex: /episodes/([^;]*?)" onmousedown
def processSingle(slug, basedir):
def processSingle(slug: str, basedir: str) -> None:
''' [internal] process single episode if only the slug is known. '''
# cms-pk = 91947 , slug = '91947-do-i-know-you'
all_slugs = [slug for _, _, _, slug in db_slugs]
if slug not in all_slugs:

View File

@@ -1,20 +1,20 @@
#!/usr/bin/env python3
from botlib.tgclient import TGClient
from botlib.tgclient import TGClient, Message
bot = TGClient(__API_KEY__, polling=True, allowedUsers=['my-username'])
@bot.message_handler(commands=['hi'])
def bot_reply(message):
def bot_reply(message: Message) -> None:
    ''' /hi handler: greet the (single) whitelisted user. '''
    if not bot.allowed(message):
        return
    bot.reply_to(message, 'Good evening my dear.')
@bot.message_handler(commands=['set'])
def update_config(message):
def update_config(message: Message) -> None:
if bot.allowed(message):
try:
config = data_store.get(message.chat.id)
config = DATA_STORE.get(message.chat.id)
except KeyError:
bot.reply_to(message, 'Not found.')
return
@@ -28,32 +28,32 @@ def update_config(message):
@bot.message_handler(commands=['start'])
def new_chat_info(message):
def new_chat_info(message: Message) -> None:
    ''' /start handler: log chat metadata, then begin config creation.

    Note: the rendered span held both the stale `data_store` line and the
    renamed `DATA_STORE` line (diff artifact); only the renamed form — the
    one the rest of the commit uses — is kept.
    '''
    bot.log_chat_info(message.chat)
    if bot.allowed(message):
        if DATA_STORE.get(message.chat.id):
            bot.reply_to(message, 'Already exists')
        else:
            CreateNew(message)  # start the interactive setup flow
class CreateNew:
def __init__(self, message):
def __init__(self, message: Message) -> None:
    # Entry point of the guided setup; immediately asks for a name.
    self.ask_name(message)
def ask_name(self, message):
def ask_name(self, message: Message) -> None:
    ''' Step 1: prompt for a name via a force-reply message. '''
    prompt = bot.send_force_reply(message.chat.id, 'Enter Name:')
    bot.register_next_step_handler(prompt, self.ask_interval)
def ask_interval(self, message):
def ask_interval(self, message: Message) -> None:
    ''' Step 2: remember the chosen name, then offer interval buttons. '''
    self.name = message.text
    prompt = bot.send_buttons(message.chat.id, 'Update interval (minutes):',
                              options=[3, 5, 10, 15, 30, 60])
    bot.register_next_step_handler(prompt, self.finish)
def finish(self, message):
def finish(self, message: Message) -> None:
try:
interval = int(message.text)
interval = int(message.text or 'error')
except ValueError:
bot.send_abort_keyboard(message, 'Not a number. Aborting.')
return

View File

@@ -17,15 +17,15 @@ bot.set_on_kill(cron.stop)
def main():
def clean_db(_):
def clean_db(_) -> None:
    # Cron callback (timer arg unused): prune old rows from the cache DB.
    Log.info('[clean up]')
    OnceDB('cache.sqlite').cleanup(limit=150)
def notify_jobA(_):
def notify_jobA(_) -> None:
    # Cron callback (timer arg unused): fetch job A, push results to chat A.
    jobA.download(topic='development', cohort='dev:py')
    send2telegram(__A_CHAT_ID__)
def notify_jobB(_):
def notify_jobB(_) -> None:
    # Cron callback (timer arg unused): fetch job B, push results to chat B.
    jobB.download()
    send2telegram(__ANOTHER_CHAT_ID__)
@@ -37,14 +37,15 @@ def main():
# cron.fire()
def send2telegram(chat_id):
def send2telegram(chat_id: int) -> None:
db = OnceDB('cache.sqlite')
# db.mark_all_done()
def _send(cohort, uid, obj):
def _send(cohort: str, uid: str, obj: str) -> bool:
Log.info('[push] {} {}'.format(cohort, uid))
return bot.send(chat_id, obj, parse_mode='HTML',
disable_web_page_preview=True)
msg = bot.send(chat_id, obj, parse_mode='HTML',
disable_web_page_preview=True)
return msg is not None
if not db.foreach(_send):
# send() sleeps 45 sec (on error), safe to call immediately

View File

@@ -4,7 +4,7 @@ from botlib.html2list import HTML2List, MatchGroup
from botlib.oncedb import OnceDB
def download(*, topic='motherboard', cohort='vice:motherboard'):
def download(*, topic: str = 'motherboard', cohort: str = 'vice:mb') -> None:
db = OnceDB('cache.sqlite')
url = 'https://www.vice.com/en/topic/{}'.format(topic)

View File

@@ -2,19 +2,26 @@
from botlib.curl import Curl
from botlib.html2list import HTML2List, MatchGroup
from botlib.oncedb import OnceDB
from typing import Optional, Callable, TextIO
CRAIGSLIST = 'https://newyork.craigslist.org/search/boo'
def load(url):
def load(url: str) -> Optional[TextIO]:
    ''' Fetch the page body; uncomment the local file for offline tests. '''
    # return open('test.html')
    return Curl.get(url)
def download():
def download() -> None:
db = OnceDB('cache.sqlite')
def proc(cohort, source, select, regex={}, fn=str):
def proc(
cohort: str,
source: Optional[TextIO],
select: str,
regex: dict = {},
fn: Callable[[MatchGroup], str] = str
) -> None:
match = MatchGroup(regex)
for elem in reversed(HTML2List(select).parse(source)):
match.set_html(elem)