详解Flux.jl

本文将详细介绍Julia语言中的一个深度学习库——Flux.jl,目的是在理解其内部结构之后,能在其之上做个性化定制。

核心概念

TrackedArray

TrackedArray类型用来对最基本的数组做封装。我们知道,深度学习框架带来的最大好处之一就是不用手写梯度反传的函数,其实现是基于这样一个事实,对于一类基本的函数,其梯度的计算方式是已知的,于是通过链式法则可以实现对整个网络中的每个参数进行更新。因此,一个TrackedArray类型应该至少包含

  1. 数据,即数组当前的值
  2. 映射函数,描述当前数据是根据怎样的函数(以及对应的参数)得到的,从而方便进一步反传
  3. 梯度,当前数据的梯度

然后我们看看源码中的定义

struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
  tracker::Tracked{A}
  data::A
  grad::A
  TrackedArray{T,N,A}(t::Tracked{A}, data::A) where {T,N,A} = new(t, data)
  TrackedArray{T,N,A}(t::Tracked{A}, data::A, grad::A) where {T,N,A} = new(t, data, grad)
end

可以看到,代码的定义与我们的直觉相符,这里的tracker字段就是用来记录data字段是怎么得到的。再具体看下Tracked{T}定义

mutable struct Tracked{T}
  ref::UInt32
  f::Call
  isleaf::Bool
  grad::T
  Tracked{T}(f::Call) where T = new(0, f, false)
  Tracked{T}(f::Call, grad::T) where T = new(0, f, false, grad)
  Tracked{T}(f::Call{Void}, grad::T) where T = new(0, f, true, grad)
end

ref先不管,isleaf用来标志当前是否是叶子节点,grad用来记录梯度(TODO:似乎跟TrackedArray中有重复?其实从数据结构上来看,需要有这么个地方做缓存。),最关键的f记录了作用的函数以及其参数,下面是Call的(定义)[https://github.com/FluxML/Flux.jl/blob/master/src/tracker/Tracker.jl#L18]:

struct Call{F,As<:Tuple}
  func::F
  args::As
end

前向计算

现在我们了解了TrackedArray的组成,但是具体怎么做前向计算的呢?

通过param构造TrackedArray

万丈高楼平地起!

在julia中,数据一般以AbstractArray的形式存在,首先,要将这类数据构造成TrackedArray,然后才能对不同的TrackedArray做前向计算。param函数就是用来构造TrackedArray的。

# https://github.com/FluxML/Flux.jl/blob/master/src/tracker/Tracker.jl#L106
param(xs::AbstractArray) = TrackedArray(float.(xs))
# https://github.com/FluxML/Flux.jl/blob/master/src/tracker/array.jl#L24
TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zeros(x))
# https://github.com/FluxML/Flux.jl/blob/master/src/tracker/Tracker.jl#L24
Call() = Call(nothing, ())

这里用一个例子来确认下:

the param function

the param function

julia> w1 = [1 2; 3 4]
2×2 Array{Int64,2}:
 1  2
 3  4

julia> w1_tracked = param(w1)
Tracked 2×2 Array{Float64,2}:
 1.0  2.0
 3.0  4.0

julia> w1_tracked.data
2×2 Array{Float64,2}:
 1.0  2.0
 3.0  4.0

julia> w1_tracked.grad
2×2 Array{Float64,2}:
 0.0  0.0
 0.0  0.0

julia> w1_tracked.tracker.f
Flux.Tracker.Call{Void,Tuple{}}(nothing, ())

julia> w1_tracked.tracker.isleaf
true

接下来看看如何对TrackedArray做运算。在Flux里array.jl文件中,做了大量的封装工作,目的主要有:

  1. TrackedArray看作普通的AbstractArray,把系统对Array的一些操作绑定到data字段上
  2. 重载一些基本的数组运算,通过track函数将对TrackedArray的运算结果封装成新的TrackedArray

细究track函数的话,稍微有点复杂:

function track(f, xs...; kw...)
  # 前向计算,得到结果y和反向求导函数back
  y, back = _forward(f, xs...; kw...)
  # 生成新的Tracked结构
  track(Call(back, tracker.(xs)), y)
end

其中_forward函数是通过一个@grad动态定义的(这个宏稍稍有点复杂,核心是定义前向和反向求导的计算方式),在重载(或者定义)每个计算函数的时候都要声明。

macro grad(ex)
  @capture(shortdef(ex), (name_(args__) = body_) |
                         (name_(args__) where {T__} = body_)) || error("Need a function definition")
  T == nothing && (T = [])
  isexpr(name, :(::)) || (name = :(::typeof($name)))
  insert!(args, 1+isexpr(args[1], :parameters) , name)
  @q(Tracker._forward($(args...)) where $(T...) = $body) |> esc
end

于是,前向计算的问题基本解决了,同时反向计算需要的偏导函数也准备好了。为了支持除了built-in计算方式之外的一些常见的函数(比如Softmax, Relu等),Flux单独开发了一个库NNlib.jl

反向传播

从代码逻辑上来讲,反向传播的实现很容易:

  1. 从后往前计算偏导并更新TrackedArraygrad字段
  2. 根据偏导更新weight

不过有些小细节需要处理:

function back!(x, Δ)
  istracked(x) || return
  scan(x)
  back(tracker(x), Δ)
  return
end

这里,scan的目的是重置整个网络中的grad

function scan(x::Tracked)
  x.isleaf && return
  ref = x.ref += 1
  if ref == 1
    scan(x.f)
    isdefined(x, :grad) && (x.grad = zero_grad!(x.grad))
  end
  return
end

可以看到,ref的作用是引用计数(这里先+=1,后面back执行的时候会-=1),反向传播的时候,会将多次计数的grad进行累加,直至计算完成后再真正执行back_:

function back(x::Tracked, Δ)
  x.isleaf && (x.grad = accum!(x.grad, Δ); return)
  ref = x.ref -= 1
  if ref > 0 || isdefined(x, :grad)
    if isdefined(x, :grad)
      x.grad = accum!(x.grad, Δ)
    else
      x.grad = Δ
    end
    ref == 0 && back_(x.f, x.grad)
  else
    ref == 0 && back_(x.f, Δ)
  end
  return
end

back_的逻辑就很简单了:

function back_(c::Call, Δ)
  Δs = c.func(Δ)
  (Δs isa Tuple && length(Δs) >= length(c.args)) ||
    error("Gradient is not a tuple of length $(length(c.args))")
  foreach(back, c.args, data.(Δs))
end

计算偏导并迭代下去。

接下来是update!

function update!(x, Δ)
  x.data .+= data(Δ)
  tracker(x).grad .= 0
  return x
end

对,如果手动更新的话,就这么简单了。

不过大多时候,都有个Optimiser,如SGD,Adam等,来辅助更新梯度。Flux在这方面没有任何特殊之处,作者用一个Param结构来管理data和Δ。

struct Param{T}
  x::T
  Δ::T
en

然后,各个Optimizer管理自己的状态,主要是通过闭包实现的。

layer

layer对一些常见的模块做了封装,如RNN和CNN等。写起来确实简单,不过,感觉需要有benchmark测试下性能。

其它

剩下的主要就是一些工具函数了,比如treelikeonehot等。