算額あれこれ

算額問題をコンピュータで解きます

Julia に翻訳--196 判別分析,二次の判別関数

#==========
Julia の修行をするときに,いろいろなプログラムを書き換えるのは有効な方法だ。
以下のプログラムを Julia に翻訳してみる。

判別分析(二次の判別関数)
http://aoki2.si.gunma-u.ac.jp/R/quad_disc.html

ファイル名: quaddisc.jl  関数名: quaddisc

翻訳するときに書いたメモ

==========#

using Statistics, Rmath, LinearAlgebra, NamedArrays, Plots

function quaddisc(data, group)
    vname = names(data)
    data = Matrix(data)
    n, p = size(data)
    gname, ni = table(group)
    k = length(ni)
    groupmeans = zeros(p, k)
    vars = zeros(p, p, k)
    invvars = similar(vars)
    for (i, g) in enumerate(gname)
        data2 = data[group .== g, :]
        groupmeans[:, i] = mean(data2, dims=1)
        cov2 = cov(data2)
        vars[:, :, i] = cov2
        invvars[:, :, i] = inv(cov2)
    end
    scores = zeros(n, k)
    for i = 1:n
        g = group[i]
        temp = data[i, :]
        for j = 1:k
            temp2 = temp .- groupmeans[:, j]
            scores[i, j] = temp2' * invvars[:, :, j] * temp2
        end
    end
    println("\nユークリッドの二乗距離\n", NamedArray(scores, (1:n, gname)))
    pvalues = pchisq.(scores, p, false)
    prediction = [gname[argmax(pvalues[i, :])] for i = 1:n]
    println("\n各群への所属確率\n", NamedArray(cat(pvalues, prediction, dims=2),
                                            (1:n, vcat(gname, "prediction"))))
    real, pred, correcttable = table(group, prediction)
    correctrate = sum(diag(correcttable)) / n * 100
    println("\n予測結果\n", NamedArray(correcttable, (real, pred)))
    println("\n正判別率 = $correctrate %")
    Dict(:groupmeans => groupmeans, :vars => vars, :invvars => invvars,
         :scores => scores, :pvalues => pvalues, :prediction => prediction,
         :correcttable => correcttable, :correctrate => correctrate)
end

function table(x) # indices が少ないとき
    indices = sort(unique(x))
    counts = zeros(Int, length(indices))
    for i in indexin(x, indices)
        counts[i] += 1
    end
    return indices, counts
end

function table(x, y) # 二次元
    indicesx = sort(unique(x))
    indicesy = sort(unique(y))
    counts = zeros(Int, length(indicesx), length(indicesy))
    for (i, j) in zip(indexin(x, indicesx), indexin(y, indicesy))
        counts[i, j] += 1
    end
    return indicesx, indicesy, counts
end

using RDatasets
iris = dataset("datasets", "iris")
quaddisc(iris[:, 1:4], iris[:, 5])
#=====

ユークリッドの二乗距離
150×3 Named Matrix{Float64}
A ╲ B │     setosa  versicolor   virginica
──────┼───────────────────────────────────
1     │   0.449114     114.804     182.936
2     │    2.08109     83.3154     153.975
3     │    1.28434     94.9204     160.494
4     │    1.70621     82.7788     140.641
5     │   0.761685     120.481     184.037
6     │    3.71265      120.48     183.298
7     │     3.4242     96.0035     154.019
8     │   0.343439     103.815     167.444
9     │    2.99648      73.947     132.766
10    │    3.20009     93.4959     156.739
11    │    1.89095     129.675     198.805
12    │    2.01488     99.6173     154.294
⋮                ⋮           ⋮           ⋮
139   │    489.209      8.6334     3.06747
140   │    688.018     21.4058     2.31141
141   │    808.103     45.4817     2.17213
142   │    678.615      48.093     8.31315
143   │    582.651     19.2259     1.89443
144   │    848.939     31.0498     1.33526
145   │    851.482     49.9147      3.1402
146   │    698.916      45.633     4.54549
147   │    578.018     23.3641      4.0077
148   │    618.186      16.741     1.11181
149   │    720.692     33.2029     3.94189
150   │    550.788     10.1125     2.69108

各群への所属確率
150×4 Named Matrix{Any}
A ╲ B │       setosa    versicolor     virginica    prediction
──────┼───────────────────────────────────────────────────────
1     │     0.978262   6.86992e-24   1.74568e-38      "setosa"
2     │     0.720846   3.45379e-17   2.86279e-32      "setosa"
3     │     0.864028   1.18489e-19   1.14537e-33      "setosa"
4     │      0.78959   4.48817e-17   2.05743e-29      "setosa"
5     │      0.94351   4.21616e-25   1.01264e-38      "setosa"
6     │     0.446289   4.21814e-25   1.45938e-38      "setosa"
7     │     0.489498   6.97144e-20   2.80168e-32      "setosa"
8     │      0.98684   1.51497e-21     3.698e-35      "setosa"
9     │     0.558415   3.32733e-15   9.97411e-28      "setosa"
10    │     0.524917   2.38001e-19   7.31636e-33      "setosa"
11    │     0.755807   4.56942e-27   6.78869e-42      "setosa"
12    │     0.733022   1.18664e-20   2.44609e-32      "setosa"
⋮                  ⋮             ⋮             ⋮             ⋮
139   │ 1.44521e-104     0.0709452      0.546599   "virginica"
140   │ 1.36991e-147   0.000263072      0.678692   "virginica"
141   │ 1.34966e-173    3.15701e-9      0.704135   "virginica"
142   │ 1.48777e-145   9.02582e-10     0.0807578   "virginica"
143   │ 8.80541e-125   0.000709559      0.755167   "virginica"
144   │ 1.92311e-182    2.99056e-6      0.855365   "virginica"
145   │ 5.41049e-183   3.76203e-10      0.534644   "virginica"
146   │  5.9837e-150    2.93632e-9      0.337188   "virginica"
147   │ 8.85759e-124   0.000107086      0.404965   "virginica"
148   │ 1.79562e-132    0.00217025      0.892394   "virginica"
149   │ 1.15234e-154    1.08551e-6      0.413927   "virginica"
150   │ 6.90934e-118     0.0385747      0.610776   "virginica"

予測結果
3×3 Named Matrix{Int64}
     A ╲ B │     setosa  versicolor   virginica
───────────┼───────────────────────────────────
setosa     │         50           0           0
versicolor │          0          47           3
virginica  │          0           0          50

正判別率 = 98.0 %

Dict{Symbol, Any} with 8 entries:
  :pvalues      => [0.978262 6.86992e-24 1.74568e-38; 0.720846 3.45379e-17 2.86…
  :vars         => [0.124249 0.0992163 0.0163551 0.0103306; 0.0992163 0.14369 0…
  :scores       => [0.449114 114.804 182.936; 2.08109 83.3154 153.975; … ; 720.…
  :prediction   => ["setosa", "setosa", "setosa", "setosa", "setosa", "setosa",…
  :correcttable => [50 0 0; 0 47 3; 0 0 50]
  :correctrate  => 98.0
  :invvars      => [18.9434 -12.4048 -4.50021 -4.77613; -12.4048 15.5705 1.1110…
  :groupmeans   => [5.006 5.936 6.588; 3.428 2.77 2.974; 1.462 4.26 5.552; 0.24…
=====#