function [w,h,z,xa2] = cplca( x, K, T, z, w, h, iter, sz, sw, sh, lz, lw, lh, pl)
% function [w,h,z,xa] = cplca( x, K, T, z, w, h, iter, sz, sw, sh, lz, lw, lh, pl)
%
% Perform convolutive PLCA with two factors in arbitrary dimensions
%  model is x = sum_i z(i)*conv( w{i}, h{i})
%
% Inputs:
%  x     input distribution
%  K     number of components
%  T     size of components (one entry per dimension of x)
%  z     initial value of p(z) [default = random]
%  w     initial value of p(w), cell array of K kernels [default = random]
%  h     initial value of p(h), cell array of K activations [default = random]
%  iter  number of EM iterations [default = 100]
%  sz    sparsity of z [default = 0]
%  sw    sparsity of w [default = 0]
%  sh    sparsity of h [default = 0]
%  lz    flag to estimate z [default = 1]
%  lw    flag to estimate w [default = 1]
%  lh    flag to estimate h [default = 1]
%  pl    plot flag [default = 1]
%
%  Sparsity parameters may be a scalar, a 1-by-K row (one value per
%  component) or a matrix with one row per schedule point; schedules whose
%  row count differs from iter are linearly resampled to iter rows.
%  Trailing arguments may be omitted; z, w, h, iter, sz, sw, sh and pl may
%  also be passed as [] to request the default.
%
% Outputs:
%  w   p(w) - vertical bases (cell array, one per component)
%  h   p(h) - horizontal bases (cell array, one per component)
%  z   p(z) - component priors
%  xa  cell array holding each component's contribution to the
%      approximation of the input (only computed when requested)

% Paris Smaragdis 2006-2008, paris@media.mit.edu

d = length(T);  % number of dimensions in the input (and the kernel)

% Sort out the sizes
wc = 2*size(x)-T;   % FFT size used for all the convolutions/correlations
hc = size(x)+T-1;

% Fill in defaults for omitted (or empty) arguments
if nargin < 4,  z = []; end
if nargin < 5,  w = []; end
if nargin < 6,  h = []; end
if nargin < 7 || isempty( iter), iter = 100; end
if nargin < 8,  sz = []; end
if nargin < 9,  sw = []; end
if nargin < 10, sh = []; end
if nargin < 11, lz = 1; end
if nargin < 12, lw = 1; end
if nargin < 13, lh = 1; end
if nargin < 14 || isempty( pl), pl = 1; end

% Initialize
if isempty( w)
	w = cell( 1, K);
	for k = 1:K
		w{k} = rand( T);
		w{k} = w{k} / sum( w{k}(:));
	end
end
if isempty( h)
	h = cell( 1, K);
	for k = 1:K
		h{k} = rand( size(x)-T+1);
		h{k} = h{k} / sum(h{k}(:));
	end
end
if isempty( z)
	z = rand(1, K);
	z = z /sum(z);
end

% Sort out sparsity parameters: one row per iteration, one column per component
sw = cplca_sparsity_schedule( sw, iter, K);
sh = cplca_sparsity_schedule( sh, iter, K);
sz = cplca_sparsity_schedule( sz, iter, K);

% Subscripts for the multidimensional crop/flip operations
ixa = cell( 1, d);    % crops the reconstruction to the size of x
iflip = cell( 1, d);  % reverses every dimension of the ratio x./xa
iw = cell( 1, d);     % picks the kernel-sized part of a correlation
ih = cell( K, d);     % picks the activation-sized part of a correlation
for i = 1:d
	ixa{i} = 1:size(x,i);
	iflip{i} = size(x,i):-1:1;
	iw{i} = size(x,i)-(1:T(i))+1;
	for k = 1:K
		ih{k,i} = hc(i)-(T(i)+(1:size(h{k},i))-2);
	end
end

% Iterate
fh = cell( 1, K);
fw = cell( 1, K);
tic;lt = toc;
for it = 1:iter

	% E-step: current model approximation and the ratio x./xa
	xa = eps;
	for k = 1:K
		fh{k} = fftn( h{k}, wc);
		fw{k} = fftn( w{k}, wc);
		xa = xa + z(k)*abs( real( ifftn( fw{k} .* fh{k})));
	end
	xa = xa(ixa{:});
	xbar = x ./ xa;
	xbar = xbar(iflip{:});
	fx = fftn( xbar, wc);

	% M-step
	for k = 1:K

		% Update W
		if lw
			c = abs( real( ifftn( fx .* fh{k})));
			nw = c(iw{:});
			nw = nw .* w{k};
		end

		% Update H
		if lh
			c = abs( real( ifftn( fx .* fw{k})));
			nh = c(ih{k,:});
			nh = nh .* h{k};
		end

		% Impose sparsity constraints. A zero sparsity value means "no
		% prior", so the entropic estimator is skipped for that entry
		% (calling it with 0 would produce NaNs).
		if lw && sw(it,k) ~= 0
			nw = lambert_compute_with_offset( nw, sw(it,k), 0) + eps;
		end
		if lh && sh(it,k) ~= 0
			nh = lambert_compute_with_offset( nh, sh(it,k), 0) + eps;
		end

		% Update z (needs one of the unnormalized factor updates)
		if lz
			if lw
				z(k) = sum( nw(:));
			elseif lh
				z(k) = sum( nh(:));
			end
		end

		% Assign and normalize
		if lw
			w{k} = nw / sum(nw(:));
		end
		if lh
			h{k} = nh / sum(nh(:));
		end
	end

	% Sparsity for z
	if lz
		if sz(it) ~= 0
			z = lambert_compute_with_offset( z, sz(it), 0) + eps;
		end
		z = z / sum(z);
	end

	% Show me (at most about once per second, and on the last iteration)
	if (toc -lt > 1 || it == iter) && pl
		subplot( 3, 3, 1), imagesc( reshape( x, size( x, 1), [])/max(x(:))), title( num2str( it))
		subplot( 3, 3, 2), imagesc( reshape( xa, size( xa, 1), [])/max(xa(:)))
		subplot( 3, 3, 3), stem( z), axis tight
		for k = 1:K
			subplot( 3, K, K+k), imagesc( squeeze( w{k})/max(w{k}(:))), title( num2str( -sum(w{k}(:).*log(w{k}(:)+eps))))
			set( gca, 'xtick', [], 'ytick', []);
			subplot( 3, K, 2*K+k), imagesc( squeeze(h{k})/max(h{k}(:))), title( num2str( -sum(h{k}(:).*log(h{k}(:)+eps))))
			set( gca, 'xtick', [], 'ytick', []);
		end
		drawnow
		lt = toc;
	end
end

% Make reconstruction: one cell per component, cropped to the size of x
if nargout == 4
	xa2 = cell( 1, K);
	for k = 1:K
		fhk = fftn( h{k}, wc);
		fwk = fftn( w{k}, wc);
		xk = z(k)*abs( real( ifftn( fwk .* fhk)));
		xa2{k} = xk(ixa{:});
	end
end


%--------------------------------------------------
function s = cplca_sparsity_schedule( s, iter, K)
% Expand a sparsity specification to an iter-by-K matrix
%  []      -> no sparsity (zeros)
%  scalar  -> same value for every component and iteration
%  1-by-K  -> per-component values, constant over iterations
%  n-by-1  -> schedule shared by all components
%  n-by-K  -> per-component schedule
% Schedules with n ~= iter rows are linearly resampled to iter rows.

if isempty( s)
	s = 0;
end
if numel( s) == 1
	s = s*ones( 1, K);
end
if size( s, 2) == 1
	s = repmat( s, 1, K);
end
if size( s, 1) == 1
	s = repmat( s, iter, 1);
elseif size( s, 1) ~= iter
	n = size( s, 1);
	% interp1 resamples every column along the row (iteration) axis
	s = interp1( (1:n)', s, linspace( 1, n, iter)', 'linear');
end


%--------------------------------------------------
function thet = lambert_compute_with_offset(omeg, z, lam_offset)
% Perform Lambert's W iterations
% fixed-point iterations of eqns 14 and 19 in [1]
% [1] Brand, M. "Pattern Discovery via Entropy Minimization"
% Madhu Shashanka, <shashanka@cns.bu.edu>
% 01 Aug 2006
% ASSUMPTION: ------> z is a scalar
%
% Inputs:
%  omeg        array of non-negative evidence values (any shape)
%  z           scalar prior weight; its sign selects the Lambert W branch
%  lam_offset  offset subtracted from the multiplier estimate
%
% Output:
%  thet        array of the same shape as omeg, normalized to sum to 1

sz = size( omeg);
omeg = omeg(:)+eps;

% With z == 0 the expressions below evaluate 0*log(0) and omeg/0 and
% return NaNs; return the plain normalized input instead.
if z == 0
	thet = reshape( omeg / sum( omeg), sz);
	return
end

oz = -omeg/z;
sgn = sign(-z);

posz = z >= 0;
if posz
	br = -1; % the branch of Lambert Function to evaluate
	lambda = min( z * (log(z) - log(omeg) - 2) - 1) - lam_offset; % initialization
else
	br = 0;
	lambda = - sum(omeg) - min(log(omeg));
end
lambda = lambda*ones( size( omeg));
thet = zeros( size( omeg));
for lIter = 1:2
	la = log(sgn*oz) + (1+lambda/z);

	% Entries whose exponent can be passed through exp() go to lambertw_new;
	% the remaining ones (including NaNs) use the out-of-range routine.
	if posz
		good = la > -745;
	else
		good = la < 709;
	end
	bad = ~good;
	thet(good) = oz(good) ./ lambertw_new(br, sgn*exp(la(good)) );
	thet(bad) = oz(bad) ./ lambert_arg_outof_range( la(bad) );
	thet = thet / sum(thet);
	lambda = -omeg./thet - z.*log( thet) - z - lam_offset;
%	lambda = mean(-omeg./thet - z.*log( thet) - z - lam_offset);
end

thet = reshape( thet, sz);


%--------------------------------------------------
function w = lambert_arg_outof_range(x)
% Computes value of the lambert function W(z)
% for values of z = -exp(-x) that are outside the range of values
% of digital floating point representations.
%
% Algorithm:
% Eq (38) and Eq (39), page 23 from
% Brand, M. "Structure learning in Conditional Probability
% Models via an entropic prior and parameter extinction", 1998.
% Available at: http://www.merl.com/reports/docs/TR98-18.pdf
%
% Madhu Shashanka, <shashanka@cns.bu.edu>
%
% Iterates w <- x - log(|w|) starting from w = x until the largest
% relative change drops below eps. Empty input is returned unchanged.
% The loop also stops when the relative change is NaN (e.g. NaN input)
% or after a fixed number of iterations, so it can never spin forever.

w = x;
if isempty( x)
	return
end

maxIter = 1000;  % safety cap; the iteration normally settles in a few steps
for n = 1:maxIter
	wo = w;
	w = x - log( abs( w));
	relchg = max( abs( (w-wo)./wo));
	if relchg < eps || isnan( relchg)
		break
	end
end
