function [w,h,z] = plca( x, K, iter, sz, sw, sh, z, w, h, pl, lw, lh)
% function [w,h,z] = plca( x, K, iter, sz, sw, sh, z, w, h, pl, lw, lh)
%
% Perform 2-D PLCA: fit the M-by-N input x (up to scale) with the model
% w*diag(z)*h using EM iterations.
%
% Inputs:
%  x     input distribution (M-by-N, presumably nonnegative)
%  K     number of components
%  iter  number of EM iterations [default = 100]
%  sz    sparsity of z [default = 0]
%  sw    sparsity of w [default = 0]
%  sh    sparsity of h [default = 0]
%  z     initial value of p(z) [default = random]
%  w     initial value of p(w) [default = random]
%  h     initial value of p(h) [default = random]
%  pl    plot flag [default = 1]
%  lw    columns of w to learn [default = 1:K]
%  lh    rows of h to learn [default = 1:K]
%
%  Sparsity arguments are entropic-prior weights handed to
%  lambert_compute_with_offset. Each may be:
%   - a scalar (same weight for every component and iteration),
%   - a 1-by-K row (per component, constant over iterations), or
%   - a T-by-K schedule whose rows are linearly resampled to iter rows.
%  For sz only the first column is used. An empty value means 0. A weight
%  of 0 at some iteration leaves that component's plain EM update alone.
%
%  Passing [] for iter, pl, z, w or h selects the default. Passing [] for
%  lw or lh disables learning of w or h. Columns of w (rows of h) not
%  listed in lw (lh) keep their initial values, normalized. If w (h) has
%  fewer than K columns (rows), the missing ones are initialized randomly.
%
% Outputs: 
%  w   p(w) - vertical bases (M-by-K, columns sum to 1)
%  h   p(h) - horizontal bases (K-by-N, rows sum to 1)
%  z   p(z) - component priors (1-by-K, sums to 1)
%

% Paris Smaragdis 2006-2008, paris@media.mit.edu

% 30-May-08 - Fixed update of z and problem with lh

% Get sizes
[M,N] = size( x);

% Defaults. exist(...,'var') so that a function or file on the path that
% happens to share an argument's name is not mistaken for the argument.
if ~exist( 'iter', 'var') || isempty( iter)
	iter = 100;
end
if ~exist( 'pl', 'var') || isempty( pl)
	pl = 1;
end
% An empty lw/lh is meaningful (learn nothing), so only a missing
% argument gets the default.
if ~exist( 'lw', 'var')
	lw = 1:K;
end
if ~exist( 'lh', 'var')
	lh = 1:K;
end
% Force row vectors: a for-loop over a column would visit it only once.
lw = lw(:)';
lh = lh(:)';

% Initialize (missing components are filled in randomly) and normalize
if ~exist( 'w', 'var') || isempty( w)
	w = rand( M, K);
elseif size( w, 2) ~= K
	w = [w rand( M, K-size( w, 2))];
end
w = w ./ repmat( sum( w, 1), M, 1);

if ~exist( 'h', 'var') || isempty( h)
	h = rand( K, N);
elseif size( h, 1) ~= K
	h = [h; rand( K-size( h, 1), N)];
end
h = h ./ repmat( sum( h, 2), 1, N);

if ~exist( 'z', 'var') || isempty( z)
	z = rand( 1, K);
	z = z / sum( z);
end
z = z(:)';

% Sort out sparsity parameters -> iter-by-K schedules
if ~exist( 'sw', 'var')
	sw = [];
end
if ~exist( 'sh', 'var')
	sh = [];
end
if ~exist( 'sz', 'var')
	sz = [];
end
sw = expand_sparsity( sw, K, iter);
sh = expand_sparsity( sh, K, iter);
sz = expand_sparsity( sz, K, iter);

% Iterate
tic; lt = toc;
for it = 1:iter

	% E-step: ratio of the data to the current model
	zh = diag( z) * h;
	R = x ./ (w*zh);

	% M-step: unnormalized posterior counts
	if ~isempty( lw)
		nw = w .* (R*zh');
	end
	if ~isempty( lh)
		nh = zh .* (w'*R);
	end
	if ~isempty( lw)
		z = sum( nw, 1);
	elseif ~isempty( lh)
		z = sum( nh, 2)';  % row, matching the sum(nw,1) branch
	end

	% Impose sparsity constraints. The test is per component:
	% lambert_compute_with_offset divides by its weight, so a zero weight
	% must skip it and keep the plain EM update.
	for k = lw
		if sw(it,k) ~= 0
			nw(:,k) = lambert_compute_with_offset( nw(:,k), sw(it,k), 0) + eps;
		end
	end
	for k = lh
		if sh(it,k) ~= 0
			nh(k,:) = lambert_compute_with_offset( nh(k,:), sh(it,k), 0) + eps;
		end
	end
	if sz(it,1) ~= 0
		z = lambert_compute_with_offset( z, sz(it,1), 0) + eps;
	end

	% Assign and normalize
	if ~isempty( lw)
		w(:,lw) = nw(:,lw) ./ repmat( sum( nw(:,lw), 1), M, 1);
	end
	if ~isempty( lh)
		h(lh,:) = nh(lh,:) ./ repmat( sum( nh(lh,:), 2), 1, N);
	end
	z = z / sum(z);

	% Show me (at most about once a second, and on the last iteration)
	if (toc -lt > 1 || it == iter) && pl
		subplot( 2, 2, 1), imagesc( x/max(x(:))), title( num2str( it))
		subplot( 2, 2, 2), imagesc( w*diag(z)*h)
		subplot( 2, 3, 4), stem( z), axis tight
		subplot( 2, 3, 5), multiplot( w'), view( [-90 90])
		subplot( 2, 3, 6), multiplot( h)
		drawnow
		lt = toc;
	end
end

%--------------------------------------------------
function s = expand_sparsity( s, K, iter)
% Expand a sparsity specification to an iter-by-K schedule.
% Empty -> 0. A scalar is used for every component. A single row is
% repeated for every iteration. Any other row count is linearly resampled,
% column by column, to iter rows.
if isempty( s)
	s = 0;
end
if numel( s) == 1
	s = s*ones( 1, K);
end
if size( s, 1) == 1
	s = repmat( s, iter, 1);
elseif size( s, 1) ~= iter
	% Resample along the iteration (row) axis only. interp1 on a matrix
	% interpolates each column. The column query keeps the result
	% iter-by-columns even when the schedule has a single column.
	s = interp1( s, linspace( 1, size( s, 1), iter)', 'linear');
end


%--------------------------------------------------
function thet = lambert_compute_with_offset(omeg, z, lam_offset, niter)
% Entropic-prior re-estimate of a distribution using Lambert's W.
% Performs the fixed-point iterations of eqns 14 and 19 in [1].
% [1] Brand, M. "Pattern Discovery via Entropy Minimization"
% Madhu Shashanka, <shashanka@cns.bu.edu>
% 01 Aug 2006
%
% Inputs:
%  omeg        nonnegative evidence (e.g. EM counts); any shape
%  z           scalar prior weight. z > 0 evaluates Lambert W on branch -1,
%              z < 0 on branch 0, and z == 0 returns omeg normalized.
%  lam_offset  offset subtracted from the Lagrange multiplier
%  niter       number of fixed-point iterations [default = 2]
%
% Output:
%  thet        estimate with the same shape as omeg, normalized to sum 1
%
% ASSUMPTION: ------> z is a scalar
% NOTE: calls lambertw_new( branch, x), which is not defined in the code
% visible here and must be on the path.

if nargin < 4 || isempty( niter)
	niter = 2;
end

sz = size( omeg);
omeg = omeg(:)+eps;

% With zero prior weight the estimate is just the normalized evidence.
% The general path below divides by z and would produce NaN.
if z == 0
	thet = reshape( omeg / sum( omeg), sz);
	return
end

oz = -omeg/z;
sgn = sign(-z);

if z >= 0
	br = -1; % the branch of Lambert Function to evaluate
	lambda = min( z * (log(z) - log(omeg) - 2) - 1) - lam_offset; % initialization
	% exp(la) underflows below about -745: use the asymptotic solver there
	inrange = @(a) a > -745;
else
	br = 0;
	lambda = - sum(omeg) - min(log(omeg));
	% exp(la) overflows above about 709: use the asymptotic solver there
	inrange = @(a) a < 709;
end
lambda = lambda*ones( size( omeg));
thet = zeros( size( omeg));
for lIter = 1:niter
	% Lambert W argument is sgn*exp(la), with la kept in log form
	la = log(sgn*oz) + (1+lambda/z);
	good = inrange( la);
	thet(good) = oz(good) ./ lambertw_new(br, sgn*exp(la(good)) );
	thet(~good) = oz(~good) ./ lambert_arg_outof_range( la(~good) );
	thet = thet / sum(thet);
	lambda = -omeg./thet - z.*log( thet) - z - lam_offset;
%	lambda = mean(-omeg./thet - z.*log( thet) - z - lam_offset);
end

thet = reshape( thet, sz);


%--------------------------------------------------
function w = lambert_arg_outof_range(x, maxit)
% Computes the Lambert function W at arguments +-exp(x) whose magnitude
% is outside the range of double-precision floating point (exp(x)
% over- or underflows), working with the exponent x directly.
%
% Solves w + log(|w|) = x with the fixed-point iteration
% w <- x - log(|w|), started at w = x. For large |x| (as used by
% lambert_compute_with_offset) w keeps the sign of x:
%   x large positive  ->  w = W_0( exp(x))
%   x large negative  ->  w = W_{-1}( -exp(x))
%
% Algorithm:
% Eq (38) and Eq (39), page 23 from
% Brand, M. "Structure learning in Conditional Probability
% Models via an entropic prior and parameter extinction", 1998.
% Available at: http://www.merl.com/reports/docs/TR98-18.pdf
%
% Inputs:
%  x      exponents (any shape; w has the same shape)
%  maxit  maximum number of iterations [default = 1000]
%
% Madhu Shashanka, <shashanka@cns.bu.edu>

if nargin < 2 || isempty( maxit)
	maxit = 1000;
end

w = x;
if ~isempty( x)
	% Bounded loop: a relative change of exactly eps (adjacent floats)
	% could otherwise keep the iteration alternating forever.
	for n = 1:maxit
		wo = w;
		w = x - log( abs( w));
		% Written as ~(d >= eps) rather than d < eps so that NaN input
		% stops the loop (and returns NaN) instead of spinning forever.
		if ~(max( abs( (w(:)-wo(:))./wo(:))) >= eps)
			break
		end
	end
end
