]> git.gir.st - subscriptionfeed.git/blob - app/common/anticsrf.py
move anticsrf out of __init__, provide decorator for opting out
[subscriptionfeed.git] / app / common / anticsrf.py
1 # TODO: build a proper flask extension
2 # Magic CSRF protection: this modifies outgoing HTML responses and injects a CSRF token into every form with method="post".
3 # All POST requests are then checked for a valid token.
4 # TODO:
5 # - knobs: mimetypes, http methods, form field name, token generator
6 # - inject a http header into all responses (that could be used by apis)
7 # - allow csrf token to be passed in http header, json, ...
8 # - allow specifying hmac message contents (currently request.remote_addr)
9 import hmac
10 import hashlib
11 from flask import request, current_app
12 from werkzeug.exceptions import BadRequest
13 from html.parser import HTMLParser
14
def init(app):
    """Install the anti-CSRF machinery on a Flask application.

    Registers ``csrf_token`` as a template global (so templates may render it
    explicitly), injects the token into outgoing HTML through an after-request
    hook, and verifies it on incoming requests through a before-request hook.

    :param app: the Flask application to protect.
    """
    # ``template_global(name=None)`` is a decorator factory: calling it with
    # the function merely builds a decorator (treating the function as the
    # ``name``) and, on Flask versions without bare-decorator support, never
    # registers anything. ``add_template_global`` registers it directly.
    app.add_template_global(csrf_token)
    app.after_request(add_csrf_protection)
    app.before_request(verify_csrf_protection)
19
def no_csrf_protection(func):
    """Exempt a view function from CSRF verification.

    Apply this decorator below ``@app.route``. It only tags the function
    with a marker attribute, which the before-request check looks up via
    the endpoint's view function. The function itself is returned as-is.
    """
    setattr(func, "_no_csrf_protection", True)
    return func
24
def csrf_token():
    """Return the CSRF token expected from the current client.

    The token is the hex-encoded HMAC-SHA256 of ``request.remote_addr``,
    keyed with the application's secret key. Must be called while a Flask
    request/app context is active.

    :returns: the token as a lowercase hex string.
    :raises TypeError: if the application has no secret key configured.
    """
    # TODO: will fail behind reverse proxy (remote_addr always localhost)
    key = current_app.secret_key
    # Flask accepts either str or bytes for SECRET_KEY, but hmac.new() only
    # takes a bytes-like key; encode str keys instead of raising TypeError.
    if isinstance(key, str):
        key = key.encode('utf-8')
    message = request.remote_addr.encode('ascii')
    return hmac.new(key, message, hashlib.sha256).hexdigest()
28
def add_csrf_protection(response):
    """After-request hook: inject the CSRF token into every POST form.

    Only ``text/html`` responses are rewritten: the body is decoded as UTF-8,
    passed through ``add_csrf`` and re-encoded. Responses in direct
    passthrough mode (file responses) are returned untouched, because
    Werkzeug's ``get_data()`` raises RuntimeError for them instead of
    buffering the body.

    :param response: the outgoing response object.
    :returns: the same response object, possibly with a rewritten body.
    """
    if response.mimetype != "text/html" or response.direct_passthrough:
        return response
    # csrf_token() yields a hex digest, so it needs no HTML escaping here.
    csrf_elem = f'<input type="hidden" name="csrf" value="{csrf_token()}"/>'
    new_response = add_csrf(response.get_data().decode('utf-8'), csrf_elem)
    response.set_data(new_response.encode('utf-8'))
    return response
35
def verify_csrf_protection():
    """Before-request hook: reject POST requests lacking a valid CSRF token.

    Views tagged by ``no_csrf_protection`` are skipped entirely. Otherwise,
    a POST whose ``csrf`` form field does not equal ``csrf_token()`` is
    rejected. Afterwards ``request.form`` is replaced by a mutable copy with
    the ``csrf`` field removed, so views never see it.

    :raises BadRequest: for a POST with a missing or wrong token.
    """
    view = current_app.view_functions.get(request.endpoint)
    # xxx: doesn't take fallback_routes into account!
    if getattr(view, '_no_csrf_protection', False):
        return

    if request.method == "POST":
        submitted = request.form.get('csrf', '')
        # Constant-time comparison, so response timing does not reveal how
        # much of a guessed token was correct. Compare bytes because
        # compare_digest() raises TypeError on non-ASCII str input.
        if not hmac.compare_digest(submitted.encode('utf-8'),
                                   csrf_token().encode('ascii')):
            raise BadRequest("CSRF validation failed")

    request.form = request.form.copy()  # make it mutable
    request.form.poplist('csrf')  # remove our csrf again
46
class _PostFormFinder(HTMLParser):
    """Record the position and raw length of every POST ``<form>`` start tag."""

    def __init__(self, html):
        super().__init__()
        self.forms = []  # tuples of (line_number, tag_offset, tag_length)
        self.feed(html)

    def handle_starttag(self, tag, attrs):
        # HTMLParser lower-cases tag and attribute names. A valueless
        # attribute (``<form method>``) has value None, hence ``or ''``.
        if tag != "form":
            return
        method = dict(attrs).get('method') or ''
        if method.upper() == "POST":
            line, offset = self.getpos()
            self.forms.append((line, offset, len(self.get_starttag_text())))


def add_csrf(html_in, csrf_elem):
    """Insert *csrf_elem* directly after every POST ``<form ...>`` start tag.

    :param html_in: the HTML document as text.
    :param csrf_elem: markup to insert verbatim (the hidden token input).
    :returns: the document with *csrf_elem* inserted; all other text is kept
        byte-for-byte.
    """
    # HTMLParser.getpos() gives a 1-based line number counted by "\n" only,
    # plus a column within that line. Convert to absolute offsets using the
    # same line definition: str.splitlines() would also break on "\r",
    # "\x0c", "\u2028", ... and misplace the insert, and per-line slicing
    # fails for start tags that span several lines.
    line_starts = [0]
    for text in html_in.split("\n"):
        line_starts.append(line_starts[-1] + len(text) + 1)

    # The parser reports forms in document order, so offsets are ascending.
    pieces = []
    prev = 0
    for line, offset, length in _PostFormFinder(html_in).forms:
        insert_at = line_starts[line - 1] + offset + length
        pieces.append(html_in[prev:insert_at])
        pieces.append(csrf_elem)
        prev = insert_at
    pieces.append(html_in[prev:])
    return "".join(pieces)
Imprint / Impressum