Thursday, May 24, 2012

Stochastic Gradient Descent Logistic Regression in SAS

This post tests Stochastic Gradient Descent (SGD) logistic regression in SAS. The logic and code follow the code snippet by Ravi Varadhan, Ph.D., from this discussion on R-help. The blog SAS Die Hard also has a post about SGD logistic regression in SAS.





/* Remote text file holding the low-birth-weight data. */
filename foo url "http://www.biostat.jhsph.edu/~ririzarr/Teaching/754/lbw.dat";

/*
  Read the ten numeric columns with list input.
  FIRSTOBS=2 skips the first record of the file at read time, so it is
  never parsed as numeric data (previously it was read, producing
  invalid-data messages, and then discarded with a subsetting IF).
  The per-record PUT echo and the unused LENGTH= variable were dropped.
*/
data temp;
   infile foo firstobs=2;
   input low age lwt race smoke ptl ht ui ftv bwt;
run;

/* Centre to mean 0 and scale to standard deviation 1, in place. */
proc standard data=temp out=temp mean=0 std=1;
   var age lwt smoke ht ui;
run;

/* Column catalogue of TEMP: one row per variable. */
proc contents data=temp noprint
              out=vars(keep=varnum name type);
run;

proc sql noprint;
   /*
     COVARS  : space-separated predictor names
     COVARS2 : the same names, each prefixed with b_
     Both lists come from the same query, so their order matches.
   */
   select name,
          cats("b_", name)
     into :covars  separated by " ",
          :covars2 separated by " "
     from vars
    where lowcase(name) in ("age", "lwt", "smoke", "ht", "ui");

   /* NPARMS : number of selected predictors plus one */
   select count(*)+1
     into :nparms
     from vars
    where lowcase(name) in ("age", "lwt", "smoke", "ht", "ui");
quit;

%put &covars2;

/* Fit the model with the %lr_sgd macro; coefficients are written to BETA. */
%lr_sgd(temp, beta, low, &covars, alpha=0.01, decay=0.98, criterion=0.00001, maxiter=1000);

/* PROC LOGISTIC fit on the same predictors; estimates go to _BETA. */
options fullstimer;

proc logistic data=temp descending outest=_beta noprint;
   model low = age lwt smoke ht ui;
run;


Execution log shows:
******************************************************************************************
606


607 %lr_sgd(temp, beta, low, &covars, alpha=0.01, decay=0.98, criterion=0.00001, maxiter=1000);

MLOGIC(LR_SGD): Beginning execution.

MLOGIC(LR_SGD): Parameter DSN has value temp

MLOGIC(LR_SGD): Parameter OUTEST has value beta

MLOGIC(LR_SGD): Parameter RESPONSE has value low

MLOGIC(LR_SGD): Parameter COVARS has value age ht lwt smoke ui

MLOGIC(LR_SGD): Parameter ALPHA has value 0.01

MLOGIC(LR_SGD): Parameter DECAY has value 0.98

MLOGIC(LR_SGD): Parameter CRITERION has value 0.00001

MLOGIC(LR_SGD): Parameter MAXITER has value 1000

MPRINT(LR_SGD): options nosource nonotes;

MPRINT(LR_SGD): options nomlogic nomprint;

Iteration 1, time used 0.046, converge criteria is 0.355

Iteration 2, time used 0.032, converge criteria is 0.1886732527

Iteration 3, time used 0.14, converge criteria is 0.110486903

Iteration 4, time used 0.031, converge criteria is 0.0694367942

Iteration 5, time used 0.031, converge criteria is 0.0458381768

Iteration 6, time used 0.032, converge criteria is 0.0313858459

Iteration 7, time used 0.031, converge criteria is 0.0221141453

Iteration 8, time used 0.047, converge criteria is 0.0159488055

Iteration 9, time used 0.031, converge criteria is 0.0125402184

Iteration 10, time used 0.031, converge criteria is 0.0102548103

Iteration 11, time used 0.031, converge criteria is 0.008416673

Iteration 12, time used 0.047, converge criteria is 0.0069325467

Iteration 13, time used 0.032, converge criteria is 0.0057299824

Iteration 14, time used 0.031, converge criteria is 0.0047522323

Iteration 15, time used 0.031, converge criteria is 0.0039546116

Iteration 16, time used 0.031, converge criteria is 0.0033017935

Iteration 17, time used 0.047, converge criteria is 0.0027657536

Iteration 18, time used 0.031, converge criteria is 0.0023241894

Iteration 19, time used 0.032, converge criteria is 0.0019592976

Iteration 20, time used 0.046, converge criteria is 0.0016568231

Iteration 21, time used 0.032, converge criteria is 0.0014053171

Iteration 22, time used 0.031, converge criteria is 0.0011955575

Iteration 23, time used 0.031, converge criteria is 0.0010200936

Iteration 24, time used 0.047, converge criteria is 0.0008728877

Iteration 25, time used 0.031, converge criteria is 0.0007490333

Iteration 26, time used 0.032, converge criteria is 0.0006445313

Iteration 27, time used 0.031, converge criteria is 0.0005561129

Iteration 28, time used 0.047, converge criteria is 0.0004810987

Iteration 29, time used 0.031, converge criteria is 0.0004172864

Iteration 30, time used 0.031, converge criteria is 0.0003628606

Iteration 31, time used 0.031, converge criteria is 0.0003163211

Iteration 32, time used 0.032, converge criteria is 0.0002764247

Iteration 33, time used 0.047, converge criteria is 0.0002421384

Iteration 34, time used 0.031, converge criteria is 0.0002126018

Iteration 35, time used 0.047, converge criteria is 0.0001870963

Iteration 36, time used 0.046, converge criteria is 0.0001650204

Iteration 37, time used 0.032, converge criteria is 0.0001458692

Iteration 38, time used 0.031, converge criteria is 0.000129218

Iteration 39, time used 0.031, converge criteria is 0.0001147086

Iteration 40, time used 0.047, converge criteria is 0.0001020383

Iteration 41, time used 0.031, converge criteria is 0.0000909506

Iteration 42, time used 0.032, converge criteria is 0.0000812278

Iteration 43, time used 0.031, converge criteria is 0.0000726846

Iteration 44, time used 0.047, converge criteria is 0.0000651629

Iteration 45, time used 0.031, converge criteria is 0.0000585277

Iteration 46, time used 0.031, converge criteria is 0.0000526634

Iteration 47, time used 0.031, converge criteria is 0.0000474707

Iteration 48, time used 0.047, converge criteria is 0.0000428642

Iteration 49, time used 0.063, converge criteria is 0.0000387705

Iteration 50, time used 0.031, converge criteria is 0.0000351261

Iteration 51, time used 0.031, converge criteria is 0.000031876

Iteration 52, time used 0.032, converge criteria is 0.0000289726

Iteration 53, time used 0.031, converge criteria is 0.0000263748

Iteration 54, time used 0.031, converge criteria is 0.0000240465

Iteration 55, time used 0.031, converge criteria is 0.0000219565

Iteration 56, time used 0.031, converge criteria is 0.0000200774

Iteration 57, time used 0.047, converge criteria is 0.0000183854

Iteration 58, time used 0.172, converge criteria is 0.0000168595

Iteration 59, time used 0.031, converge criteria is 0.0000154815

Iteration 60, time used 0.032, converge criteria is 0.0000142352

Iteration 61, time used 0.031, converge criteria is 0.0000131064

Iteration 62, time used 0.031, converge criteria is 0.0000120826

Iteration 63, time used 0.031, converge criteria is 0.0000111528

Iteration 64, time used 0.032, converge criteria is 0.0000103072

Iteration 65, time used 0.031, converge criteria is 9.5372542E-6

Total Time is 2.56 sec.

Total Iteration is 65, convergence status is Converged.

At Final Iteration, max difference is 9.5372542E-6

MPRINT(LR_SGD): source;

MLOGIC(LR_SGD): Ending execution.


******************************************************************************************

The macro %LR_SGD.

/*
  SAS macro:
     Logistic Regression using Stochastic Gradient Descent.    
  Name: 
     %lr_sgd();
  Copyright (c) 2009, Liang Xie (Contact me @ xie1978 at gmail dot com)
  
  
  The SAS macro is a demonstration of an implementation of logistic 
  regression models trained by Stochastic Gradient Descent (SGD). This 
  program reads a training set specified as &dsn, trains a logistic 
  regression model, and outputs the estimated coefficients to &outest. 
  Example usage:

  %let inputdata=train_data;
  %let beta=coefficient;
  %let response=Event;
  %lr_sgd(&inputdata, &beta, &response, &covars, 
          alpha=0.008, decay=0.8, 
          criterion=0.00001, maxiter=1000);


  The following topics are not covered for simplicity:    
      - bias term    
      - regularization    
      - multiclass logistic regression (maximum entropy model)         
      - calibration of learning rate

  Distributed under GNU Affero General Public License version 3. This 
  program is free software: you can redistribute it and/or modify
  it under the terms of the GNU Affero General Public License as
  published by the Free Software Foundation, only version 3 of the
  License.

  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  GNU Affero General Public License for more details. 

*/

/*
  %logistic : one pass of coefficient updates over the training data.

  Parameters
    dsn_in   - training data set; must contain INTERCEPT, the variables
               listed in COVARS, and the response variable.
    outest   - TYPE=PARMS data set with the current coefficients; also
               names the working data set _x<outest> that is updated.
    response - name of the response variable.
    alpha    - step size applied to every update (default 0.0005).

  Reads the macro variables COVARS and COVARS2 from the calling scope;
  they are not parameters of this macro.

  Side effects: replaces WORK.SCORE and the view _XTEMP, and rewrites
  the single row of _x<outest> in place.

  Changes from the earlier version:
    - the view now merges the DSN_IN parameter instead of a variable
      named DSN that had to exist in the caller scope;
    - dropped the unused retained variables LOGNEG and LOGPOS;
    - dropped a copy loop whose result was overwritten by the next SET;
    - dropped commented-out dead code.
*/
%macro logistic(dsn_in, outest, response, alpha=0.0005);

/* Linear predictor for every row, using the coefficients in OUTEST. */
proc score data=&dsn_in  score=&outest  type=parms  out=score(keep=score);
     var intercept &covars;
run;

/* Row-wise residual: response minus the fitted probability. */
data _xtemp / view=_xtemp;
     merge &dsn_in score;
     _w = &response - 1/(1+exp(-score));
run;

/*
  Fold every row of the view into the coefficient row.
  All scores were computed before this step starts, so each row's
  residual is based on the coefficients as they were at pass start.
*/
data _x&outest;
     array x{*}  intercept &covars;
     array _x{*} b_intercept &covars2;
     retain b_intercept &covars2;
     modify _x&outest;

     do until (eof);
        set _xtemp end=eof;
        do i = 1 to dim(x);
           _x[i] = _x[i] + &alpha * x[i] * _w;
        end;
     end;
     replace;
run;

%mend logistic;

/*
  %compare : largest absolute coefficient change between two estimates.

  Parameters
    dsn1 - data set holding INTERCEPT and the COVARS variables.
    dsn2 - data set holding B_INTERCEPT and the COVARS2 variables.

  Reads the macro variables COVARS and COVARS2 from the calling scope.

  Result: the macro variable MAXDIFF receives the maximum of
  abs(_x1[i] - _x2[i]) over all array positions.  CALL SYMPUTX is used
  without a scope argument, so an existing MAXDIFF in an enclosing
  scope is updated exactly as CALL SYMPUT would do, but the value is
  converted explicitly: no leading blanks, no numeric-to-character
  conversion note, and a wider numeric field than CALL SYMPUT uses.
*/
%macro compare(dsn1, dsn2);
data _null_;
     merge &dsn1  &dsn2;
     array _x1{*} intercept &covars;
     array _x2{*} b_intercept &covars2;
     retain maxdiff 0;
     do i = 1 to dim(_x1);
        maxdiff = max(maxdiff, abs(_x1[i] - _x2[i]));
     end;
     call symputx('maxdiff', maxdiff);
run;
%mend compare;




/*
  %lr_sgd : iterative logistic-regression fit driven by %logistic and
            %compare.

  Parameters
    dsn       - training data set.  It is rewritten in place with two
                added columns, INTERCEPT=1 and _W=1.
    outest    - output data set of coefficients (TYPE=PARMS layout with
                INTERCEPT and the COVARS variables).
    response  - name of the response variable.
    covars    - space-separated list of predictor variable names.
    alpha     - initial step size (default 0.0005).
    decay     - factor applied to the step size after every iteration
                (default 0.9); the step size never drops below 0.00005.
    criterion - stop when the largest coefficient change in an
                iteration is below this value (default 0.00001).
    maxiter   - maximum number of iterations (default 1000).

  Also leaves behind the working data set _x<outest>, WORK.SCORE and
  the view _XTEMP.

  Changes from the earlier version:
    - the SOURCE, NOTES, MLOGIC and MPRINT settings in effect at entry
      are saved and restored on exit instead of being forced on;
    - T00 and T11 are declared local so they cannot overwrite
      same-named variables of the caller;
    - COVARS2 (the b_-prefixed name list) is derived from COVARS when
      the caller has not defined a non-blank COVARS2.  A caller-supplied
      COVARS2 is still honoured unchanged.
*/
%macro lr_sgd(dsn, outest, response, covars, 
              alpha=0.0005, decay=0.9, 
              criterion=0.00001, maxiter=1000);
%local i t0 t1 t00 t11 dt maxdiff status stopiter
       saved_opts derive_b k v;

/* Capture the current option settings before silencing the log. */
%let saved_opts=%sysfunc(getoption(source)) %sysfunc(getoption(notes));
%let saved_opts=&saved_opts %sysfunc(getoption(mlogic)) %sysfunc(getoption(mprint));

options nosource nonotes;
options nomlogic nomprint;

/* Build COVARS2 locally unless the caller already provides one. */
%let derive_b=1;
%if %symexist(covars2) %then %do;
    %if %length(&covars2) > 0 %then %let derive_b=0;
%end;
%if &derive_b %then %do;
    %local covars2;
    %let k=1;
    %let v=%scan(&covars, &k, %str( ));
    %do %while (%length(&v) > 0);
        %let covars2=&covars2 b_&v;
        %let k=%eval(&k + 1);
        %let v=%scan(&covars, &k, %str( ));
    %end;
%end;

%let t00=%sysfunc(datetime());

/* Add the constant column used for the intercept term. */
data &dsn;
     set &dsn;
     intercept=1; _w=1;
run;

/* Starting coefficients: all zero. */
data &outest;
     retain _TYPE_ "PARMS"  _NAME_ "SCORE";
     array x{*} intercept &covars;
     do i=1 to dim(x); 
        x[i]=0;
     end;
     drop i;
     output;
run;

/* Working copy of the coefficients under the b_ names. */
data _x&outest;
     retain _TYPE_ "PARMS"  _NAME_ "SCORE";
     array bx{*} b_intercept &covars2;
     array x{*}  intercept &covars;
     set &outest;
     do j=1 to dim(x); bx[j]=x[j]; end;
     keep b_intercept &covars2 _TYPE_  _NAME_;
     drop j;
run;

sasfile _x&outest load;
%let stopiter=&maxiter;
%let status=Not Converged.;
%do i=1 %to &maxiter;
    %let t0=%sysfunc(datetime());

    /* Update the working coefficients, then measure the change. */
    %logistic(&dsn, &outest, &response, alpha=&alpha);
    %compare(&outest, _x&outest);    

    /* Publish the updated coefficients back under the plain names. */
    data &outest;
         retain _TYPE_ "PARMS"  _NAME_ "SCORE";
         array bx{*} b_intercept &covars2;
         array x{*}  intercept &covars;
         set _x&outest;
         do j=1 to dim(x); x[j]=bx[j]; end;
         keep intercept &covars _TYPE_  _NAME_;
         drop j;
    run;

    /* Shrink the step size, bounded below by 0.00005. */
    %let alpha=%sysevalf(&alpha * &decay);
    %let alpha=%sysfunc(max(0.00005, &alpha));

    %let t1=%sysfunc(datetime());
    %let dt=%sysfunc(round(&t1-&t0, 0.001));
    %put Iteration &i, time used &dt, converge criteria is &maxdiff; 

    %if %sysevalf(&maxdiff<&criterion) %then %do;
        %let stopiter=&i;
        /* Push the index past the bound so the loop ends. */
        %let i=%eval(&maxiter+1);
        %let status=Converged.;
    %end;
%end;
sasfile _x&outest close;

%let t11=%sysfunc(datetime());
%let dt=%sysfunc(round(&t11-&t00, 0.01));
%put Total Time is &dt sec.;
%put Total Iteration is &stopiter, convergence status is &status;
%put At Final Iteration, max difference is &maxdiff;

/* Put the log options back the way the caller had them. */
options &saved_opts;
%mend lr_sgd;