玩转 OpenResty 协程 API,


注意:本文中列出的所有代码只是 Proof Of Concept,基本上都没有进行错误处理。另外对于一些边际情况,也可能没有考虑清楚。所以对于直接复制文中代码到项目中所造成的一切后果,请自负责任。

OK,言归正题。OpenResty 提供了以 ngx.thread.*coroutine.*ngx.semaphore 等一系列协程 API。虽然受限于 Nginx 的请求处理方式,表现力不如通用语言的协程 API 那么强大。但是开开脑洞,还是可以玩出一些花样来的。
借助这些 API,让我们尝试模拟下其他编程平台里面的调度方式。

模拟 Java 里面的 Future

Java 里的 Future 可以让我们创建一个任务,然后在需要的时候才去 get 任务的返回值。另外 Future 还有超时功能。
我们可以启用一个协程来完成具体的任务,再加一个定时结束的协程,用于实现超时。

像这样:

local function task()
    ngx.sleep(3)
    ngx.say("Done")
end

local task_thread = ngx.thread.spawn(task)
local timeout_thread = ngx.thread.spawn(function(timeout)
    ngx.sleep(timeout)
    error("timeout")
end, 2)
local ok, res = ngx.thread.wait(task_thread, timeout_thread)
if not ok then
    if res == "timeout" then
        ngx.thread.kill(task_thread)
        ngx.say("task cancelled by timeout")
        return
    end
    ngx.say("task failed, result: ", res)
end
ngx.thread.kill(timeout_thread)

注意一点,在某一协程退出之后,我们需要 kill 掉另外一个协程。因为如果没有调用 ngx.exit 之类的方法显式退出的话,一直到所有协程退出为止,当前阶段都不会结束。

引用文档里相关的内容:

By default, the corresponding Nginx handler (e.g., rewrite_by_lua handler) will not terminate until

模拟 Javascript 里面的 Promise.race/all

Promise.race/all 可以接收多个 Promise,然后打包成一个新的 Promise 返回。引用相关的文档:

The Promise.race(iterable) method returns a promise that resolves or rejects as soon as one of the promises in the iterable resolves or rejects, with the value or reason from that promise.

The Promise.all(iterable) method returns a promise that resolves when all of the promises in the iterable argument have resolved, or rejects with the reason of the first passed promise that rejects.

这里 reject 等价于协程运行中抛出 error,而 resolve 相对于协程返回了结果。这两个 API 对于 reject 的处理是一致的,都是有任一出错则立刻返回异常结果。对于正常结果,race 会在第一个结果出来时返回,而 all 则会在所有结果都出来后返回。
值得注意的是,Javascript 原生的 Promise 暂时没有 cancell 的功能。所以即使其中一个 Promise reject 了,其他 Promise 依然会继续运行。对此我们也照搬过来。

Promise.race 的实现:

local function apple()
    ngx.sleep(0.1)
    --error("apple lost")
    return "apple done"
end

local function banana()
    ngx.sleep(0.2)
    return "banana done"
end

local function carrot()
    ngx.sleep(0.3)
    return "carrot done"
end

local function race(...)
    local functions = {...}
    local threads = {}
    for _, f in ipairs(functions) do
        local th, err = ngx.thread.spawn(f)
        if not th then
            -- Promise.race 没有实现 cancell 接口,
            -- 所以我偷下懒,不管已经创建的协程了
            return nil, err
        end
        table.insert(threads, th)
    end
    local ok, res = ngx.thread.wait(unpack(threads))
    if not ok then
        return nil, res
    end
    return res
end

local res, err = race(apple, banana, carrot)
ngx.say("res: ", res, " err: ", err)
ngx.exit(ngx.OK)

Promise.all 的实现:

local function all(...)
    local functions = {...}
    local threads = {}
    for _, f in ipairs(functions) do
        local th, err = ngx.thread.spawn(f)
        if not th then
            return nil, err
        end
        table.insert(threads, th)
    end
    local res_group = {}
    for _ = 1, #threads do
        local ok, res = ngx.thread.wait(unpack(threads))
        if not ok then
            return nil, res
        end
        table.insert(res_group, res)
    end
    return res_group
end

模拟 Go 里面的 channel (仅部分实现)

再进一步,试试模拟下 Go 里面的 channel。
我们需要实现如下的语义:

这次要用到 ngx.semaphore

local semaphore = require "ngx.semaphore"

local Chan = {
    new = function(self)
        local chan_attrs = {
            _read_sema = semaphore:new(),
            _write_sema = semaphore:new(),
            _exclude_sema = semaphore:new(),
            _buffer = nil,
            _waiting_thread_num = 0,
        }
        return setmetatable(chan_attrs, {__index = self})
    end,
    send = function(self, value, timeout)
        timeout = timeout or 60
        while self._buffer do
            self._waiting_thread_num = self._waiting_thread_num + 1
            self._exclude_sema:wait(timeout)
            self._waiting_thread_num = self._waiting_thread_num - 1
        end
        self._buffer = value
        self._read_sema:post()
        self._write_sema:wait(timeout)
    end,
    receive = function(self, timeout)
        timeout = timeout or 60
        self._read_sema:wait(timeout)
        local value = self._buffer
        self._buffer = nil
        self._write_sema:post()
        if self._waiting_thread_num > 0 then
            self._exclude_sema:post()
        end
        return value
    end,
}

local chan = Chan:new()

-- 以下是使用方法
local function worker_a(ch)
    for i = 1, 10 do
        ngx.sleep(math.random() / 10)
        ch:send(i, 1)
    end
end

local function worker_c(ch)
    for i = 11, 20 do
        ngx.sleep(math.random() / 10)
        ch:send(i, 1)
    end
end

local function worker_d(ch)
    for i = 21, 30 do
        ngx.sleep(math.random() / 10)
        ch:send(i, 1)
    end
end


local function worker_b(ch)
    for _ = 1, 20 do
        ngx.sleep(math.random() / 10)
        local v = ch:receive(1)
        ngx.say("recv ", v)
    end
end

local function worker_e(ch)
    for _ = 1, 10 do
        ngx.sleep(math.random() / 10)
        local v = ch:receive(1)
        ngx.say("recv ", v)
    end
end

ngx.thread.spawn(worker_a, chan)
ngx.thread.spawn(worker_b, chan)
ngx.thread.spawn(worker_c, chan)
ngx.thread.spawn(worker_d, chan)
ngx.thread.spawn(worker_e, chan)

模拟 Buffered channel 也是可行的。

local ok, new_tab = pcall(require, "table.new")
if not ok then
    new_tab = function (_, _) return {} end
end


local BufferedChan = {
    new = function(self, buffer_size)
        if not buffer_size or buffer_size <= 0 then
            error("Invalid buffer_size " .. (buffer_size or "nil") .. " given")
        end
        local chan_attrs = {
            _read_sema = semaphore:new(),
            _write_sema = semaphore:new(),
            _waiting_thread_num = 0,
            _buffer_size = buffer_size,
        }
        chan_attrs._buffer = new_tab(buffer_size, 0)
        return setmetatable(chan_attrs, {__index = self})
    end,
    send = function (self, value, timeout)
        timeout = timeout or 60
        while #self._buffer >= self._buffer_size do
            self._waiting_thread_num = self._waiting_thread_num + 1
            self._write_sema:wait(timeout)
            self._waiting_thread_num = self._waiting_thread_num - 1
        end
        table.insert(self._buffer, value)
        self._read_sema:post()
    end,
    receive = function(self, timeout)
        timeout = timeout or 60
        self._read_sema:wait(timeout)
        local value = table.remove(self._buffer)
        if self._waiting_thread_num > 0 then
            self._write_sema:post()
        end
        return value
    end,
}

local chan = BufferedChan:new(2)
-- ...

当然上面的山寨货还是有很多问题的。比如它缺少至关重要的 select 支持,另外也没有实现 close 相关的特性。

相关内容

    暂无相关文章