]>
git.gir.st - subscriptionfeed.git/blob - app/common/anticsrf.py
1 # TODO: build a proper Flask extension
2 # Magic CSRF protection: this modifies outgoing HTML responses and injects a CSRF token into every POST form.
3 # All POST requests are then checked for a valid token.
5 # - knobs: mimetypes, HTTP methods, form field name, token generator
6 # - inject an HTTP header into all responses (which could be used by APIs)
7 # - allow the CSRF token to be passed in an HTTP header, JSON body, ...
8 # - allow specifying the HMAC message contents (currently request.remote_addr)
11 from flask
import request
, current_app
12 from werkzeug
.exceptions
import BadRequest
13 from html
.parser
import HTMLParser
16 app
.template_global(csrf_token
)
17 app
.after_request(add_csrf_protection
)
18 app
.before_request(verify_csrf_protection
)
def no_csrf_protection(func):
    """Exempt a view function from CSRF verification.

    Place this decorator *below* ``@app.route`` so it is applied first and
    the route registers the marked function. The ``_no_csrf_protection``
    marker it sets is what ``verify_csrf_protection`` looks up.

    :param func: the view function to exempt.
    :returns: the same function object, with ``_no_csrf_protection`` set.
    """
    func._no_csrf_protection = True
    # A decorator must hand back the callable; without this the route
    # would end up registering ``None`` as its view function.
    return func
26 # TODO: will fail behind reverse proxy (remote_addr always localhost)
27 return hmac
.new(current_app
.secret_key
, request
.remote_addr
.encode('ascii'), hashlib
.sha256
).hexdigest()
def add_csrf_protection(response):
    """Inject the CSRF token into the POST forms of an outgoing HTML page.

    Used as an ``after_request`` hook. Responses that are not
    ``text/html``, that come from the ``static`` endpoint, or that are in
    direct-passthrough mode are returned untouched.

    :param response: the response object produced by the view.
    :returns: the (possibly modified) response. Flask requires
        ``after_request`` handlers to return it.
    """
    if (
        response.mimetype == "text/html"
        and request.endpoint != 'static'
        # get_data() raises RuntimeError on passthrough responses
        # (e.g. send_file), so leave those alone.
        and not response.direct_passthrough
    ):
        csrf_elem = f'<input type="hidden" name="csrf" value="{csrf_token()}"/>'
        new_response = add_csrf(response.get_data().decode('utf-8'), csrf_elem)
        response.set_data(new_response.encode('utf-8'))
    return response
def verify_csrf_protection():
    """Reject POST requests that do not carry the expected CSRF token.

    Used as a ``before_request`` hook. Views marked by
    ``no_csrf_protection`` (attribute ``_no_csrf_protection``) are skipped
    entirely. For all other requests the ``csrf`` field is removed from
    ``request.form`` afterwards, so views never see it.

    :raises BadRequest: if a POST request has a missing or wrong token.
    """
    import hmac  # constant-time comparison of the token

    view = current_app.view_functions.get(request.endpoint)
    #^xxx: doesn't take fallback_routes into account!
    if getattr(view, '_no_csrf_protection', False):
        # Returning None lets Flask continue to the exempt view.
        return

    if request.method == "POST":
        submitted = request.form.get('csrf', '')
        # compare_digest avoids leaking the token through timing. Compare
        # bytes: it rejects non-ASCII str input with TypeError, which an
        # attacker-controlled form value could otherwise trigger.
        if not hmac.compare_digest(submitted.encode('utf-8'),
                                   csrf_token().encode('ascii')):
            raise BadRequest("CSRF validation failed")

    request.form = request.form.copy()  # make it mutable
    request.form.poplist('csrf')  # remove our csrf again
47 def add_csrf(html_in
, csrf_elem
):
48 class FindForms(HTMLParser
):
49 def __init__(self
, html
):
51 self
.forms
= [] # tuples of (line_number, tag_offset, tag_length)
53 def handle_starttag(self
, tag
, attrs
):
54 line
, offset
= self
.getpos()
55 if tag
== "form" and dict(attrs
).get('method','').upper() == "POST":
56 self
.forms
.append((line
, offset
, len(self
.get_starttag_text())))
57 lines
= html_in
.splitlines(keepends
=True)
58 # Note: going in reverse, to not invalidate offsets:
59 for line
, offset
, length
in reversed(FindForms(html_in
).forms
):
61 lines
[line
-1] = l
[:offset
+length
] + csrf_elem
+ l
[offset
+length
:]