webmcp
view libraries/mondelefant/mondelefant_atom_connector.lua @ 565:4e5d8d6c0d7c
Updated LICENSE file
| author | jbe | 
|---|---|
| date | Wed Apr 28 12:53:54 2021 +0200 (2021-04-28) | 
| parents | c839cbd66598 | 
| children | 
 line source
     1 #!/usr/bin/env lua
     3 local _G             = _G
     4 local _VERSION       = _VERSION
     5 local assert         = assert
     6 local error          = error
     7 local getmetatable   = getmetatable
     8 local ipairs         = ipairs
     9 local next           = next
    10 local pairs          = pairs
    11 local print          = print
    12 local rawequal       = rawequal
    13 local rawget         = rawget
    14 local rawlen         = rawlen
    15 local rawset         = rawset
    16 local select         = select
    17 local setmetatable   = setmetatable
    18 local tonumber       = tonumber
    19 local tostring       = tostring
    20 local type           = type
    22 local math      = math
    23 local string    = string
    24 local table     = table
    26 local mondelefant = require("mondelefant")
    27 local atom        = require("atom")
    28 local json        = require("json")
    30 local _M = {}
    31 if _ENV then
    32   _ENV = _M
    33 else
    34   _G[...] = _M
    35   setfenv(1, _M)
    36 end
    39 input_converters = setmetatable({}, { __mode = "k" })
    41 input_converters["boolean"] = function(conn, value, rawtext_mode)
    42   if rawtext_mode then
    43     if value then return "t" else return "f" end
    44   else
    45     if value then return "TRUE" else return "FALSE" end
    46   end
    47 end
    49 input_converters["number"] = function(conn, value, rawtext_mode)
    50   if _VERSION == "Lua 5.2" then
    51     -- TODO: remove following compatibility hack to allow large integers (e.g. 1e14) in Lua 5.2
    52     local integer_string = string.format("%i", value)
    53     if tonumber(integer_string) == value then
    54       return integer_string
    55     else
    56       local number_string = tostring(value)
    57       if string.find(number_string, "^[0-9.e+-]+$") then
    58         return number_string
    59       else
    60         if rawtext_mode then return "NaN" else return "'NaN'" end
    61       end
    62     end
    63   end
    64   local integer = math.tointeger(value)
    65   if integer then
    66     return tostring(integer)
    67   end
    68   local str = tostring(value)
    69   if string.find(str, "^[0-9.e+-]+$") then
    70     return str
    71   end
    72   if rawtext_mode then return "NaN" else return "'NaN'" end
    73 end
    75 input_converters[atom.fraction] = function(conn, value, rawtext_mode)
    76   if value.invalid then
    77     if rawtext_mode then return "NaN" else return "'NaN'" end
    78   else
    79     local n, d = tostring(value.numerator), tostring(value.denominator)
    80     if string.find(n, "^%-?[0-9]+$") and string.find(d, "^%-?[0-9]+$") then
    81       if rawtext_mode then
    82         return n .. "/" .. d
    83       else
    84         return "(" .. n .. "::numeric / " .. d .. "::numeric)"
    85       end
    86     else
    87       if rawtext_mode then return "NaN" else return "'NaN'" end
    88     end
    89   end
    90 end
    92 input_converters[atom.date] = function(conn, value, rawtext_mode)
    93   if rawtext_mode then
    94     return tostring(value)
    95   else
    96     return conn:quote_string(tostring(value)) .. "::date"
    97   end
    98 end
   100 input_converters[atom.timestamp] = function(conn, value, rawtext_mode)
   101   if rawtext_mode then
   102     return tostring(value)
   103   else
   104     return conn:quote_string(tostring(value))  -- don't define type
   105   end
   106 end
   108 input_converters[atom.time] = function(conn, value, rawtext_mode)
   109   if rawtext_mode then
   110     return tostring(value)
   111   else
   112     return conn:quote_string(tostring(value)) .. "::time"
   113   end
   114 end
   116 input_converters["rawtable"] = function(conn, value, rawtext_mode)
   117   -- treat tables as arrays
   118   local parts = { "{" }
   119   for i, v in ipairs(value) do
   120     if i > 1 then parts[#parts+1] = "," end
   121     local converter =
   122       input_converters[getmetatable(v)] or
   123       input_converters[type(v)]
   124     if converter then
   125       v = converter(conn, v, true)
   126     else
   127       v = tostring(v)
   128     end
   129     parts[#parts+1] = '"'
   130     parts[#parts+1] = string.gsub(v, '[\\"]', '\\%0')
   131     parts[#parts+1] = '"'
   132   end
   133   parts[#parts+1] = "}"
   134   return conn:quote_string(table.concat(parts))
   135 end
   138 output_converters = setmetatable({}, { __mode = "k" })
   140 output_converters.int8 = function(str) return atom.integer:load(str) end
   141 output_converters.int4 = function(str) return atom.integer:load(str) end
   142 output_converters.int2 = function(str) return atom.integer:load(str) end
   144 output_converters.numeric = function(str) return atom.number:load(str) end
   145 output_converters.float4  = function(str) return atom.number:load(str) end
   146 output_converters.float8  = function(str) return atom.number:load(str) end
   148 output_converters.bool = function(str) return atom.boolean:load(str) end
   150 output_converters.date = function(str) return atom.date:load(str) end
   152 local function timestamp_loader_func(str)
   153   local year, month, day, hour, minute, second = string.match(
   154     str,
   155     "^([0-9][0-9][0-9][0-9])%-([0-9][0-9])%-([0-9][0-9]) ([0-9]?[0-9]):([0-9][0-9]):([0-9][0-9])"
   156   )
   157   if year then
   158     return atom.timestamp{
   159       year   = tonumber(year),
   160       month  = tonumber(month),
   161       day    = tonumber(day),
   162       hour   = tonumber(hour),
   163       minute = tonumber(minute),
   164       second = tonumber(second)
   165     }
   166   else
   167     return atom.timestamp.invalid
   168   end
   169 end
   170 output_converters.timestamp = timestamp_loader_func
   171 output_converters.timestamptz = timestamp_loader_func
   173 local function time_loader_func(str)
   174   local hour, minute, second = string.match(
   175     str,
   176     "^([0-9]?[0-9]):([0-9][0-9]):([0-9][0-9])"
   177   )
   178   if hour then
   179     return atom.time{
   180       hour   = tonumber(hour),
   181       minute = tonumber(minute),
   182       second = tonumber(second)
   183     }
   184   else
   185     return atom.time.invalid
   186   end
   187 end
   188 output_converters.time = time_loader_func
   189 output_converters.timetz = time_loader_func
   191 local function json_loader_func(str)
   192   return assert(json.import(str))
   193 end
   194 output_converters.json = json_loader_func
   195 output_converters.jsonb = json_loader_func
   197 mondelefant.postgresql_connection_prototype.type_mappings = {
   198   int8 = atom.integer,
   199   int4 = atom.integer,
   200   int2 = atom.integer,
   201   bool = atom.boolean,
   202   date = atom.date,
   203   timestamp = atom.timestamp,
   204   time = atom.time,
   205   text = atom.string,
   206   varchar = atom.string,
   207   json = json,
   208   jsonb = json,
   209 }
   212 function mondelefant.postgresql_connection_prototype.input_converter(conn, value, info)
   213   if value == nil then
   214     return "NULL"
   215   else
   216     local mt = getmetatable(value)
   217     local converter = input_converters.mt
   218     if not converter then
   219       local t = type(value)
   220       if t == "table" and mt == nil then
   221         converter = input_converters.rawtable
   222       else
   223         converter = input_converters.t
   224       end
   225     end
   226     local converter =
   227       input_converters[getmetatable(value)] or
   228       input_converters[type(value)]
   229     if converter then
   230       return converter(conn, value)
   231     else
   232       return conn:quote_string(tostring(value))
   233     end
   234   end
   235 end
   237 function mondelefant.postgresql_connection_prototype.output_converter(conn, value, info)
   238   if value == nil then
   239     return nil
   240   else
   241     local array_type = nil
   242     if info.type then
   243       array_type = string.match(info.type, "^(.*)_array$")
   244     end
   245     if array_type then
   246       local result = {}
   247       local count = 0
   248       local inner_converter = output_converters[array_type]
   249       if not inner_converter then
   250         inner_converter = function(x) return x end
   251       end
   252       value = string.match(value, "^{(.*)}$")
   253       if not value then
   254         error("Could not parse database array")
   255       end
   256       local pos = 1
   257       while pos <= #value do
   258         count = count + 1
   259         if string.find(value, '^""$', pos) then
   260           result[count] = inner_converter("")
   261           pos = pos + 2
   262         elseif string.find(value, '^"",', pos) then
   263           result[count] = inner_converter("")
   264           pos = pos + 3
   265         elseif string.find(value, '^"', pos) then
   266           local p1, p2, entry = string.find(value, '^"(.-[^\\])",', pos)
   267           if not p1 then
   268             p1, p2, entry = string.find(value, '^"(.*[^\\])"$', pos)
   269           end
   270           if not entry then error("Could not parse database array") end
   271           entry = string.gsub(entry, "\\(.)", "%1")
   272           result[count] = inner_converter(entry)
   273           pos = p2 + 1
   274         else
   275           local p1, p2, entry = string.find(value, '^(.-),', pos)
   276           if not p1 then p1, p2, entry = string.find(value, '^(.*)$', pos) end
   277           result[count] = inner_converter(entry)
   278           pos = p2 + 1
   279         end
   280       end
   281       return result
   282     else
   283       local converter = output_converters[info.type]
   284       if converter then
   285         return converter(value)
   286       else
   287         return value
   288       end
   289     end
   290   end
   291 end
   294 function mondelefant.save_mutability_state(value)
   295   local jsontype = json.type(value)
   296   if jsontype == "object" or jsontype == "array" then
   297     return tostring(value)
   298   end
   299 end
   301 function mondelefant.verify_mutability_state(value, state)
   302   return tostring(value) ~= state
   303 end
   306 return _M
   309 --[[
   311 db = assert(mondelefant.connect{engine='postgresql', dbname='test'})
   312 result = db:query{'SELECT ? + 1', atom.date{ year=1999, month=12, day=31}}
   313 print(result[1][1].year)
   315 --]]
