# Copyright (C) 2005 JanRain, Inc.
# Copyright (C) 2009, 2010 Canonical Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from apache_openid import logging
from apache_openid.action import Action
from apache_openid.handlers.openid.mixins import ProvidersMixin
from apache_openid.utils import FieldStorage
from openid.consumer import consumer


class ReturnAction(Action, ProvidersMixin):
    """Handle the OpenID server's redirect back to the ``return`` action.

    Completes the transaction through ``self.consumer`` and always ends by
    redirecting: to the session target (or the default) on a trusted
    success, otherwise back to the login page with a status message.
    """

    def do(self):
        """Handle a response from the OpenID server. Always redirects.

        Raises:
            AssertionError: if the consumer reports a status this handler
                does not know about.
        """
        auth_response = self.get_response(FieldStorage(self.request))
        if auth_response is None:
            self.response.login_redirect(message='failure')
            return
        status = auth_response.status
        if status == consumer.SUCCESS:
            self.on_success(auth_response)
        elif status == consumer.CANCEL:
            self.response.login_redirect(message='cancel')
        elif status == consumer.FAILURE:
            self.response.login_redirect(message='failure')
        elif status == consumer.SETUP_NEEDED:
            # Back to the login page without a status message.
            self.response.login_redirect()
        else:
            # An explicit raise instead of ``assert False``: asserts are
            # stripped under ``python -O``, which would let an unknown
            # status fall through without any redirect at all.
            raise AssertionError(status)

    def get_response(self, form):
        """Complete the OpenID transaction from the submitted form fields.

        Only the first value of each field is used, decoded from UTF-8.
        Returns whatever ``self.consumer.complete`` returns; ``do`` treats a
        ``None`` result as a failure.
        """
        query = dict((key, form.getfirst(key).decode('utf-8'))
                     for key in form.keys())
        return self.consumer.complete(query, self.request.action_url('return'))

    def on_success(self, auth_response):
        """Log the user in if the endpoint is trusted, else redirect to login."""
        op = self.fetch_op_for_endpoint(auth_response.endpoint.server_url)
        if op is None:  # Not a trusted endpoint
            # Keyword form, consistent with every other login_redirect call,
            # so 'failure' is guaranteed to bind to ``message``.
            self.response.login_redirect(message='failure')
        else:
            logging.debug("OpenID success!")
            self.complete_login(auth_response)

    def fetch_op_for_endpoint(self, endpoint):
        """Retrieve the (trusted) op that led to a given endpoint, if any.

        Returns:
            ``True`` when no allowed OPs are configured (any endpoint is
            accepted); otherwise the first allowed op whose stored entry
            matches ``endpoint``; ``None`` when nothing matches.
        """
        # If there are no allowed OPs specified, we should be OK with any
        if not self.allowed_ops:
            return True
        store = self.consumer.consumer.store
        for op in self.allowed_ops.values():
            # Entries are looked up by URL; bare host names get 'http://'.
            assoc_handle = op
            if '://' not in op:
                assoc_handle = 'http://' + op
            # NOTE(review): the association store appears to be reused as an
            # op -> endpoint map (handle holds the op, secret the endpoint
            # URL); confirm against the code that writes these entries.
            fake_assoc = store.getAssociation(assoc_handle)
            if fake_assoc is None:
                continue
            if fake_assoc.handle == op and fake_assoc.secret == endpoint:
                return op
        return None

    def complete_login(self, auth_response):
        """Record the authenticated identity and redirect to the target.

        Sets ``cookied_user`` and ``last_user`` on the request to the
        identity URL. It then redirects to the session's ``target``, clearing
        it so it is used only once, or to the default when none is stored.
        """
        self.request.cookied_user = auth_response.identity_url
        self.request.last_user = auth_response.identity_url
        target = self.session.get('target')
        if target is None:
            logging.debug('session target not set. redirecting to default')
            self.response.redirect()
        else:
            logging.debug('redirecting to session target %r', target)
            self.session['target'] = None
            self.response.redirect(target)
