webmcp
view framework/env/auth/openid/discover.lua @ 421:c343ce9092ee
Added downward-compatibility code for mondelefant.connect{engine='postgresql', ...} call
| author | jbe | 
|---|---|
| date | Tue Jan 12 18:57:17 2016 +0100 (2016-01-12) | 
| parents | 47ddf0f86009 | 
| children | 
 line source
     1 --[[--
     2 discovery_data,                                         -- table containing "claimed_identifier", "op_endpoint" and "op_local_identifier"
     3 errmsg,                                                 -- error message in case of failure
     4 errcode =                                               -- error code in case of failure (TODO: not implemented yet)
     5 auth.openid.discover{
     6   user_supplied_identifier = user_supplied_identifier,  -- string given by user
     7   https_as_default         = https_as_default,          -- default to https
     8   curl_options             = curl_options               -- options passed to "curl" binary, when performing discovery
     9 }
    11 --]]--
    13 -- helper function
    14 local function decode_entities(str)
    15   local str = str
    16   str = string.gsub(value, "<", '<')
    17   str = string.gsub(value, ">", '>')
    18   str = string.gsub(value, """, '"')
    19   str = string.gsub(value, "&", '&')
    20   return str
    21 end
    23 -- helper function
    24 local function get_tag_value(
    25   str,          -- HTML document or document snippet
    26   match_tag,    -- tag name
    27   match_key,    -- attribute key to match
    28   match_value,  -- attribute value to match
    29   result_key    -- attribute key of value to return
    30 )
    31   -- NOTE: The following parameters are case insensitive
    32   local match_tag   = string.lower(match_tag)
    33   local match_key   = string.lower(match_key)
    34   local match_value = string.lower(match_value)
    35   local result_key  = string.lower(result_key)
    36   for tag, attributes in
    37     string.gmatch(str, "<([0-9A-Za-z_-]+) ([^>]*)>")
    38   do
    39     local tag = string.lower(tag)
    40     if tag == match_tag then
    41       local matching = false
    42       for key, value in
    43         string.gmatch(attributes, '([0-9A-Za-z_-]+)="([^"<>]*)"')
    44       do
    45         local key = string.lower(key)
    46         local value = decode_entities(value)
    47         if key == match_key then
    48           -- NOTE: match_key must only match one key of space seperated list
    49           for value in string.gmatch(value, "[^ ]+") do
    50             if string.lower(value) == match_value then
    51               matching = true
    52               break
    53             end
    54           end
    55         end
    56         if key == result_key then
    57           result_value = value
    58         end
    59       end
    60       if matching then
    61         return result_value
    62       end
    63     end
    64   end
    65   return nil
    66 end
    68 -- helper function
    69 local function tag_contents(str, match_tag)
    70   local pos = 0
    71   local tagpos, closing, tag
    72   local function next_tag()
    73     local prefix
    74     tagpos, prefix, tag, pos = string.match(
    75       str,
    76       "()<(/?)([0-9A-Za-z:_-]+)[^>]*>()",
    77       pos
    78     )
    79     closing = (prefix == "/")
    80   end
    81   return function()
    82     repeat
    83       next_tag()
    84       if not tagpos then return nil end
    85       local stripped_tag
    86       if string.find(tag, ":") then
    87         stripped_tag = string.match(tag, ":([^:]*)$")
    88       else
    89         stripped_tag = tag
    90       end
    91     until stripped_tag == match_tag and not closing
    92     local content_start = pos
    93     local used_tag = tag
    94     local counter = 0
    95     while true do
    96       repeat
    97         next_tag()
    98         if not tagpos then return nil end
    99       until tag == used_tag
   100       if closing then
   101         if counter > 0 then
   102           counter = counter - 1
   103         else
   104           return string.sub(str, content_start, tagpos-1)
   105         end
   106       else
   107         counter = counter + 1
   108       end
   109     end
   110     local content = string.sub(rest, 1, startpos-1)
   111     str = string.sub(rest, endpos+1)
   112     return content
   113   end
   114 end
   116 local function strip(str)
   117   local str = str
   118   string.gsub(str, "^[ \t\r\n]+", "")
   119   string.gsub(str, "[ \t\r\n]+$", "")
   120   return str
   121 end
   123 function auth.openid.discover(args)
   124   local url = string.match(args.user_supplied_identifier, "[^#]*")
   125   -- NOTE: XRIs are not supported
   126   if
   127     string.find(url, "^[Xx][Rr][Ii]://") or
   128     string.find(url, "^[=@+$!(]")
   129   then
   130     return nil, "XRI identifiers are not supported."
   131   end
   132   -- Prepend http:// or https://, if neccessary:
   133   if not string.find(url, "://") then
   134     if args.default_to_https then
   135       url = "https://" .. url
   136     else
   137       url = "http://" .. url
   138     end
   139   end
   140   -- Either an xrds_document or an html_document will be fetched
   141   local xrds_document, html_document
   142   -- Repeat max 10 times to avoid endless redirection loops
   143   local redirects = 0
   144   while true do
   145     local status, headers, body = auth.openid._curl(url, args.curl_options)
   146     if not status then
   147       return nil, "Error while locating XRDS or HTML file for discovery."
   148     end
   149     -- Check, if we are redirected:
   150     local location = string.match(
   151       headers,
   152       "\r?\n[Ll][Oo][Cc][Aa][Tt][Ii][Oo][Nn]:[ \t]*([^\r\n]+)"
   153     )
   154     if location then
   155       -- If we are redirected too often, then return an error.
   156       if redirects >= 10 then
   157         return nil, "Too many redirects."
   158       end
   159       -- Otherwise follow the redirection by changing the variable "url"
   160       -- and by incrementing the redirect counter.
   161       url = location
   162       redirects = redirects + 1
   163     else
   164       -- Check, if there is an X-XRDS-Location header
   165       -- pointing to an XRDS document:
   166       local xrds_location = string.match(
   167         headers,
   168         "\r?\n[Xx]%-[Xx][Rr][Dd][Ss]%-[Ll][Oo][Cc][Aa][Tt][Ii][Oo][Nn]:[ \t]*([^\r\n]+)"
   169       )
   170       -- If there is no X-XRDS-Location header, there might be an
   171       -- http-equiv meta tag serving the same purpose:
   172       if not xrds_location and status == 200 then
   173         xrds_location = get_tag_value(body, "meta", "http-equiv", "X-XRDS-Location", "content")
   174       end
   175       if xrds_location then
   176         -- If we know the XRDS-Location, we can fetch the XRDS document
   177         -- from that location:
   178         local status, headers, body = auth.openid._curl(xrds_location, args.curl_options)
   179         if not status then
   180           return nil, "XRDS document could not be loaded."
   181         end
   182         if status ~= 200 then
   183           return nil, "XRDS document not found where expected."
   184         end
   185         xrds_document = body
   186         break
   187       elseif
   188         -- If the Content-Type header is set accordingly, then we already
   189         -- should have received an XRDS document:
   190         string.find(
   191           headers,
   192           "\r?\n[Cc][Oo][Nn][Tt][Ee][Nn][Tt]%-[Tt][Yy][Pp][Ee]:[ \t]*application/xrds%+xml\r?\n"
   193         )
   194       then
   195         if status ~= 200 then
   196           return nil, "XRDS document announced but not found."
   197         end
   198         xrds_document = body
   199         break
   200       else
   201         -- Otherwise we should have received an HTML document:
   202         if status ~= 200 then
   203           return nil, "No XRDS or HTML document found for discovery."
   204         end
   205         html_document = body
   206         break;
   207       end
   208     end
   209   end
   210   local claimed_identifier   -- OpenID identifier the user claims to own
   211   local op_endpoint          -- OpenID provider endpoint URL
   212   local op_local_identifier  -- optional user identifier, local to the OpenID provider
   213   if xrds_document then
   214     -- If we got an XRDS document, we look for a matching <Service> entry:
   215     for content in tag_contents(xrds_document, "Service") do
   216       local service_uri, service_localid
   217       for content in tag_contents(content, "URI") do
   218         if not string.find(content, "[<>]") then
   219           service_uri = strip(content)
   220           break
   221         end
   222       end
   223       for content in tag_contents(content, "LocalID") do
   224         if not string.find(content, "[<>]") then
   225           service_localid = strip(content)
   226           break
   227         end
   228       end
   229       for content in tag_contents(content, "Type") do
   230         if not string.find(content, "[<>]") then
   231           local content = strip(content)
   232           if content == "http://specs.openid.net/auth/2.0/server" then
   233             -- The user entered a provider identifier, thus claimed_identifier
   234             -- and op_local_identifier will be set to nil.
   235             op_endpoint = service_uri
   236             break
   237           elseif content == "http://specs.openid.net/auth/2.0/signon" then
   238             -- The user entered his/her own identifier.
   239             claimed_identifier  = url
   240             op_endpoint         = service_uri
   241             op_local_identifier = service_localid
   242             break
   243           end
   244         end
   245       end
   246     end
   247   elseif html_document then
   248     -- If we got an HTML document, we look for matching <link .../> tags:
   249     claimed_identifier = url
   250     op_endpoint = get_tag_value(
   251       html_document,
   252       "link", "rel", "openid2.provider", "href"
   253     )
   254     op_local_identifier = get_tag_value(
   255       html_document,
   256       "link", "rel", "openid2.local_id", "href"
   257     )
   258   else
   259     error("Assertion failed")  -- should not happen
   260   end
   261   if not op_endpoint then
   262     return nil, "No OpenID endpoint found."
   263   end
   264   if claimed_identifier then
   265     claimed_identifier = auth.openid._normalize_url(claimed_identifier)
   266     if not claimed_identifier then
   267       return nil, "Claimed identifier could not be normalized."
   268     end
   269   end
   270   return {
   271     claimed_identifier  = claimed_identifier,
   272     op_endpoint         = op_endpoint,
   273     op_local_identifier = op_local_identifier
   274   }
   275 end
