# TODO: build a proper flask extension
"""Magic CSRF protection for Flask applications.

This modifies outgoing HTML responses and injects a csrf token into all
POST forms. All POST requests are then checked if they contain the valid
token.
"""

# TODO:
# - knobs: mimetypes, http methods, form field name, token generator
# - inject a http header into all responses (that could be used by apis)
# - allow csrf token to be passed in http header, json, ...
# - allow specifying hmac message contents (currently request.remote_addr)

import hmac
import hashlib
from flask import request, current_app
from werkzeug.exceptions import BadRequest
from html.parser import HTMLParser


def init(app):
    """Install the CSRF protection on the Flask application *app*.

    Makes ``csrf_token`` callable from templates and registers the hooks that
    inject the token into responses and verify it on incoming requests.
    """
    # ``app.template_global(f)`` would treat ``f`` as the *name* argument of a
    # decorator factory and register nothing; add_template_global() registers
    # the function directly.
    app.add_template_global(csrf_token)
    app.after_request(add_csrf_protection)
    app.before_request(verify_csrf_protection)


def no_csrf_protection(func):
    """Mark the view *func* as exempt from CSRF verification.

    Add this decorator below ``@app.route`` so that the marked function is the
    one that gets registered as the view.
    """
    func._no_csrf_protection = True
    return func


def csrf_token():
    """Return the CSRF token expected for the current request.

    The token is the hex HMAC-SHA256 of ``request.remote_addr`` keyed with the
    application's ``secret_key``.
    """
    # TODO: will fail behind reverse proxy (remote_addr always localhost)
    key = current_app.secret_key
    if isinstance(key, str):
        # hmac.new() only accepts bytes-like keys.
        key = key.encode('utf-8')
    return hmac.new(key, request.remote_addr.encode('ascii'), hashlib.sha256).hexdigest()


def add_csrf_protection(response):
    """Inject the CSRF token field into every POST form of an HTML response.

    Responses with a mimetype other than ``text/html`` are returned unchanged.
    Returns the (possibly modified) *response*.
    """
    if response.mimetype != "text/html":
        return response
    if response.direct_passthrough:
        # get_data() raises RuntimeError for responses in direct passthrough
        # mode, so these are passed on untouched instead of failing.
        return response
    # The field name must match the one read in verify_csrf_protection().
    csrf_elem = f'<input type="hidden" name="csrf" value="{csrf_token()}">'
    html = response.get_data().decode('utf-8')
    response.set_data(add_csrf(html, csrf_elem).encode('utf-8'))
    return response


def verify_csrf_protection():
    """Reject POST requests that do not carry the valid CSRF token.

    Views marked with ``no_csrf_protection`` are skipped entirely. For all
    other requests ``request.form`` is replaced by a mutable copy with the
    ``csrf`` field removed, so views never see it.

    Raises:
        BadRequest: if a POST request has a missing or wrong ``csrf`` field.
    """
    view = current_app.view_functions.get(request.endpoint)
    # XXX: doesn't take fallback_routes into account!
    if getattr(view, '_no_csrf_protection', False):
        return
    if request.method == "POST":
        submitted = request.form.get('csrf') or ''
        expected = csrf_token()
        # Compare as bytes: compare_digest() rejects non-ASCII str arguments,
        # and the comparison must not leak timing information.
        if not hmac.compare_digest(submitted.encode('utf-8'), expected.encode('ascii')):
            raise BadRequest("CSRF validation failed")
    request.form = request.form.copy()  # make it mutable
    request.form.poplist('csrf')  # remove our csrf again


class _PostFormFinder(HTMLParser):
    """Record where each ``<form>`` start tag with method POST is located."""

    def __init__(self):
        super().__init__()
        self.forms = []  # tuples of (line_number, tag_offset, tag_length)

    def handle_starttag(self, tag, attrs):
        if tag != "form":
            return
        # A valueless attribute (``<form method>``) is reported as None.
        method = dict(attrs).get('method') or ''
        if method.upper() == "POST":
            line, offset = self.getpos()
            self.forms.append((line, offset, len(self.get_starttag_text())))


def add_csrf(html_in, csrf_elem):
    """Return *html_in* with *csrf_elem* inserted after every POST form tag.

    *csrf_elem* is placed directly behind the closing ``>`` of each ``<form>``
    start tag whose ``method`` attribute is POST (case-insensitive). All other
    text is passed through unchanged.
    """
    finder = _PostFormFinder()
    finder.feed(html_in)
    if not finder.forms:
        return html_in

    # HTMLParser counts lines by "\n" only, so line starts are computed the
    # same way; str.splitlines() would also break on "\r", "\x0c", "\u2028", ...
    line_starts = [0]
    for text in html_in.split("\n")[:-1]:
        line_starts.append(line_starts[-1] + len(text) + 1)

    # Absolute offsets keep tags that span several lines intact. The parser
    # reports the forms in document order, so one forward pass is enough.
    pieces = []
    copied_up_to = 0
    for line, offset, length in finder.forms:
        insert_at = line_starts[line - 1] + offset + length
        pieces.append(html_in[copied_up_to:insert_at])
        pieces.append(csrf_elem)
        copied_up_to = insert_at
    pieces.append(html_in[copied_up_to:])
    return "".join(pieces)