Section 33.4 Nelder-Mead method in higher dimensions
This is an optional section, feel free to skip it.
The code in Example 33.3.1 works only for functions of two variables, but it can be adapted for \(n\)-dimensional optimization. Instead of the initial triangle
A = randn(2, 1); B = A + [1; 0]; C = A + [0; 1]; T = [A B C];
we create an initial \(n\)-dimensional simplex with \(n+1\) vertices using implicit array expansion:
A = randn(n, 1); T = [A A+eye(n)];
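To see what the implicit expansion produces, here is a minimal sketch with a fixed starting vertex. The specific numbers in A are an assumption made for reproducibility; the text draws A with randn. Implicit expansion of a column against a matrix requires MATLAB R2016b or later (Octave broadcasts in the same way).

```matlab
n = 3;
A = [2; -1; 0.5];        % fixed vertex for illustration (the text uses randn(n, 1))
T = [A, A + eye(n)];     % A is added to every column of eye(n)
disp(size(T))            % n rows, n+1 columns: one column per vertex
disp(T(:, 2:end) - A)    % edge vectors leaving A: the columns of eye(n)
```

Since the edge vectors from A are the standard basis vectors, the simplex is non-degenerate whatever A happens to be.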
The search path is stored in path = zeros(n, max_tries)
which we will only use for visualization when \(n=2\) or \(n=3\text{.}\)
The line values = [f(T(:,1)) f(T(:,2)) f(T(:,3))]
resists vectorization unless we rewrite f to accept a matrix whose columns are points
(which we probably should not do because, in practice, the objective function is user-provided code). So it becomes a loop over the \(n+1\) vertices:
values = zeros(1, n+1);
for j = 1:n+1
    values(j) = f(T(:, j));
end
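If a more compact form is preferred, the same evaluation can probably be written with arrayfun, as in the sketch below. The test function (a sum of squares) and the simplex are assumptions chosen only so the result is easy to check by hand. This still calls f once per vertex, so it is a matter of style rather than speed.

```matlab
f = @(x) sum(x.^2);          % stand-in objective, assumed for illustration
n = 3;
A = [1; 2; 3];               % fixed first vertex, assumed for illustration
T = [A, A + eye(n)];

% Loop version, as in the text
values = zeros(1, n+1);
for j = 1:n+1
    values(j) = f(T(:, j));
end

% Compact alternative: one call of f per column index
values_alt = arrayfun(@(j) f(T(:, j)), 1:n+1);

disp(values)
disp(values_alt)
```

Both vectors hold the value of f at each column of T, in column order.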
The remaining adjustments are small: the midpoint of the opposite side is replaced by the centroid of the remaining \(n\) vertices, M = (sum(T, 2) - T(:, ind))/n
but the formulas for reflection \(R\text{,}\) expansion \(E\text{,}\) and contraction remain the same. The plot of the search path and the formatting of the text output need cosmetic changes, as shown in the following example.
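The centroid formula can be checked on a small case. The sketch below assumes a sum-of-squares objective and a fixed simplex (neither is from the text) and confirms that subtracting the worst vertex from the sum of all vertices leaves the mean of the other \(n\) vertices. The name W for the worst vertex is introduced here for readability.

```matlab
f = @(x) sum(x.^2);                 % stand-in objective, assumed for illustration
n = 3;
A = [1; 2; 3];
T = [A, A + eye(n)];

values = zeros(1, n+1);
for j = 1:n+1
    values(j) = f(T(:, j));
end
[fmax, ind] = max(values);          % index of the worst vertex
W = T(:, ind);                      % worst vertex (name W is ours)

M = (sum(T, 2) - W)/n;              % centroid of the remaining n vertices
others = setdiff(1:n+1, ind);
M_check = mean(T(:, others), 2);    % same centroid, computed directly

R = 2*M - W;                        % reflection of W through M
E = 2*R - M;                        % expansion: twice as far from M as R

disp([M, M_check, R, E])
```

Written relative to M, the reflection is M + (M - W) and the expansion is M + 2*(M - W), which is why neither formula depends on the dimension.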
Example 33.4.1. Implementing the Nelder-Mead method in higher dimensions.
Generalize Example 33.3.1 to higher dimensions. Use it to minimize the 3-variable Rosenbrock function \((x_1-1)^2 + 100(x_1^2-x_2)^2 + 100(x_2^2-x_3)^2\text{.}\)
The code collects the lines from the previous paragraphs.
f = @(x) (x(1)-1)^2 + 100*(x(1)^2 - x(2))^2 + 100*(x(2)^2 - x(3))^2;
n = 3;    % number of variables
A = randn(n, 1);
T = [A, A+eye(n)];
max_tries = 10000;
path = zeros(n, max_tries);
for k = 1:max_tries
    path(:, k) = mean(T, 2);
    if max(abs(T - mean(T, 2))) < 1e-6
        break
    end
    values = zeros(1, n+1);
    for j = 1:n+1
        values(j) = f(T(:, j));
    end
    [fmax, ind] = max(values);
    M = (sum(T, 2) - T(:, ind))/n;
    R = 2*M - T(:, ind);
    if f(R) < fmax
        E = 2*R - M;
        if f(E) < f(R)
            T(:, ind) = E;
        else
            T(:, ind) = R;
        end
    else
        [fmin, ind] = min(values);
        T = (T + T(:, ind))/2;
    end
end
if n == 2
    plot(path(1, 1:k), path(2, 1:k), '-+')
elseif n == 3
    plot3(path(1, 1:k), path(2, 1:k), path(3, 1:k), '-+')
end
if k < max_tries
    x = mean(T, 2);
    fprintf('Found x with f(x) = %g after %d steps\n x =', f(x), k);
    disp(x');
else
    disp('Failed to converge')
end
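As a sanity check on the result, one can compare against the built-in fminsearch, which implements a variant of the Nelder-Mead method. This is only a sketch: the starting point x0 and the tolerances are assumptions, and fminsearch builds its initial simplex and decides when to stop differently from the code above, so step counts are not comparable. Every term of the objective is a square and all three vanish at \((1,1,1)\text{,}\) so that point is a global minimum with value 0.

```matlab
f = @(x) (x(1)-1)^2 + 100*(x(1)^2 - x(2))^2 + 100*(x(2)^2 - x(3))^2;

x0 = [0; 0; 0];    % assumed starting point, fixed for reproducibility
opts = optimset('MaxFunEvals', 1e4, 'MaxIter', 1e4, ...
                'TolX', 1e-8, 'TolFun', 1e-8);   % assumed tolerances
[xmin, fval] = fminsearch(f, x0, opts);

fprintf('Value at (1,1,1): %g\n', f([1; 1; 1]));
fprintf('fminsearch found f(x) = %g at x =', fval);
disp(xmin');
```

If the point reported by Example 33.4.1 and the one found by fminsearch agree to several digits, that is reasonable evidence that the generalized code behaves as intended.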