  • How can I calculate only the diagonal elements of a matrix?

    Hi,
    I have a matrix A [i x j] and a matrix B [j x j], where i is large (~80,000). I want only the diagonal terms of A*B*A', as a vector of length i. Currently I use:

    AB_var = diagonal(A * B * A')

    This takes a very long time because it also computes all the off-diagonal entries and only then extracts the diagonal. It seems inefficient to compute the whole i x i matrix A*B*A' when I only need its diagonal.
    Any thoughts?

    Thanks guys!

    -Eric
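For reference, each diagonal element needs only one row of A together with B. With r indexing the rows of A (r = 1, ..., i) and k, l indexing its columns:

$$
(ABA')_{rr} = A_{r\cdot}\,B\,A_{r\cdot}' = \sum_{k=1}^{j}\sum_{l=1}^{j} A_{rk}\,B_{kl}\,A_{rl}, \qquad r = 1,\dots,i
$$

Assuming 8 bytes per element (double precision), the full i x i product for i = 80,000 holds 80,000^2 = 6.4 billion entries, about 51.2 billion bytes (roughly 51 GB), while the diagonal alone is 80,000 numbers (640 KB).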




  • #2
    rowsum((A:*(B*A')')) gives the same result and appears to be faster. See the example below:

    Code:
    . mata:
    ------------------------------------------------- mata (type end to exit) ---------------------------------------------------------
    : rseed(504)
    
    : A = runiform(5,5)
    
    : B = runiform(5,5)
    
    : 
    : 
    : AB = diagonal(A * B * A') 
    
    : AB
                     1
        +---------------+
      1 |  5.124582983  |
      2 |  7.965222508  |
      3 |  2.272116075  |
      4 |  2.576442027  |
      5 |  1.702678092  |
        +---------------+
    
    : 
    : C= rowsum((A:*(B*A')'))
    
    : C
                     1
        +---------------+
      1 |  5.124582983  |
      2 |  7.965222508  |
      3 |  2.272116075  |
      4 |  2.576442027  |
      5 |  1.702678092  |
        +---------------+
    
    : 
    : end
    -----------------------------------------------------------------------------------------------------------------------------------
    Code:
    . mata:
    ------------------------------------------------- mata (type end to exit) ---------------------------------------------------------
    : timer_clear()
    
    : rseed(504)
    
    : A = runiform(1000,1000)
    
    : B = runiform(1000,1000)
    
    : 
    : timer_on(1)
    
    : AB = diagonal(A * B * A') 
    
    : timer_off(1)
    
    : 
    : timer_on(2)
    
    : C= rowsum((A:*(B*A')'))
    
    : timer_off(2)
    
    : 
    : timer()
    
    -----------------------------------------------------------------------------------------------------------------------------------
    timer report
      1.        1.1 /        1 =       1.1
      2.       .635 /        1 =      .635
    -----------------------------------------------------------------------------------------------------------------------------------
    
    : end
    -----------------------------------------------------------------------------------------------------------------------------------
    
    .
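The one-liner above can be packaged as a small Mata function. This is a minimal sketch: diagquad() is a hypothetical name, not a built-in Mata function, and it uses A*B' in place of (B*A')', since the two are the same matrix. Only the i x j matrix A*B' is formed, never the i x i product.

```mata
mata:
// diagquad() is a hypothetical helper, not a built-in Mata function.
// Returns diagonal(A*B*A') as a rows(A) x 1 vector without forming
// the rows(A) x rows(A) product.
real colvector diagquad(real matrix A, real matrix B)
{
    // Element [r,k] of A*B' is the sum over l of A[r,l]*B[k,l].
    // Multiplying elementwise by A and summing across row r gives
    // A[r,.]*B*A[r,.]'. Note (B*A')' = A*B', so this is the #2 formula.
    return(rowsum(A :* (A*B')))
}

// Usage: compare against the direct formula on a small example
rseed(504)
A = runiform(5, 3)
B = runiform(3, 3)
mreldif(diagquad(A, B), diagonal(A*B*A'))   // should be at rounding-error level
end
```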



    • #3
      Scott Merryman I had a similar problem (calculating standard errors of linear predictors). Your formula is super useful!

      (My solution was to loop over the rows of A. Given the size of the matrices I work with (large "i", small "j"), this solution is still more efficient than calculating A*B*A'. However, your formula wins hands down. Thanks!)

      Code:
      . mata
      ------------------------------------------------- mata (type end to exit) -----------------------------------------------------
      :  rseed(504)
      
      :  timer_clear()
      
      : 
      :  A = runiform(10000,10)
      
      :  B = runiform(10,10)
      
      :  
      :  timer_on(1)
      
      :  AB = diagonal(A * B * A') 
      
      :  timer_off(1)
      
      :  
      :  timer_on(2)
      
      :  C= rowsum((A:*(B*A')'))
      
      :  timer_off(2)
      
      : 
      :  timer_on(3)
      
      :  D = J(rows(A),1,.)
      
      :  for (i=1; i<=rows(A); i++) {
      >          D[i,1] = A[i,] * B * A[i,]'
      > }
      
      : timer_off(3)
      
      : 
      : AB[1..10,1]
                        1
           +---------------+
         1 |  22.78286966  |
         2 |  10.08931292  |
         3 |  8.201109155  |
         4 |  9.932160502  |
         5 |  13.70261162  |
         6 |  8.312025624  |
         7 |   11.7587381  |
         8 |  9.240110691  |
         9 |  18.29079768  |
        10 |  18.21981088  |
           +---------------+
      
      : C[1..10,1]
                        1
           +---------------+
         1 |  22.78286966  |
         2 |  10.08931292  |
         3 |  8.201109155  |
         4 |  9.932160502  |
         5 |  13.70261162  |
         6 |  8.312025624  |
         7 |   11.7587381  |
         8 |  9.240110691  |
         9 |  18.29079768  |
        10 |  18.21981088  |
           +---------------+
      
      : D[1..10,1]
                        1
           +---------------+
         1 |  22.78286966  |
         2 |  10.08931292  |
         3 |  8.201109155  |
         4 |  9.932160502  |
         5 |  13.70261162  |
         6 |  8.312025624  |
         7 |   11.7587381  |
         8 |  9.240110691  |
         9 |  18.29079768  |
        10 |  18.21981088  |
           +---------------+
      
      : 
      : timer()
      
      -------------------------------------------------------------------------------------------------------------------------------
      timer report
        1.       .562 /        1 =      .562
        2.       .002 /        1 =      .002
        3.        .02 /        1 =       .02
      -------------------------------------------------------------------------------------------------------------------------------
      
      : end
      -------------------------------------------------------------------------------------------------------------------------------
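A rough multiply-add count is consistent with the timings in #2 and #3. It ignores memory traffic and loop overhead, so treat it as a guide rather than a prediction. For A of size i x j and B of size j x j, with \circ denoting elementwise multiplication (Mata's :*):

$$
\text{full: } (AB)A' \;\approx\; i j^2 + i^2 j,
\qquad
\text{rowsum: } \operatorname{rowsum}\!\big(A \circ (BA')'\big) \;\approx\; i j^2 + 2ij,
\qquad
\frac{i j^2 + i^2 j}{i j^2 + 2ij} = \frac{i + j}{j + 2}
$$

For the 10,000 x 10 case above the ratio is about 834, and for a square A (i = j) it is close to 2, so by this count the advantage grows as A gets taller relative to its width. The measured gap here (.562 vs .002, about 280x) is smaller than 834, which is not surprising when one timing is only a couple of milliseconds. The row loop performs roughly the same arithmetic as rowsum() (about ij^2 + ij), so its extra time is plausibly per-iteration overhead from running 10,000 small products in a Mata-level loop.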



      • #4
        Scott Merryman This is exactly what I was looking for. The code runs in seconds rather than in the 15-20 min range. Thank you so much!



        • #5
          Scott Merryman Great solution!
          I thought that transposing would be a costly operation in that context, but it appears not to be the case.
          Note that (B*A')' = A*B', so rowsum(A:*(A*B')) saves one transpose. Sometimes it is marginally faster and sometimes slower. Strangely, though,
          Code:
          rowsum(A:*(A*B))
          gives the same numerical result. It saves two transposes but is slightly slower.
          Note that the gain in speed increases as A gets close to square.
          Any idea why this is the case?

          Code:
          : N = 1000
          
          : k = 1000
          
          : M = 1
          
          : 
          : timer_clear()
          
          : 
          : rseed(504)
          
          : 
          : A = runiform(k,N)
          
          : 
          : B = runiform(N,N)
          
          : 
          :  
          :  timer_on(1)
          
          : 
          : for (i = 1; i<= M ;i++) AB = diagonal(A * B * A') 
          
          : 
          : timer_off(1)
          
          : 
          : 
          : timer_on(2)
          
          : 
          : for (i = 1; i<= M ;i++) C= rowsum((A:*(B*A')'))
          
          : 
          : timer_off(2)
          
          : 
          : 
          : timer_on(3)
          
          : 
          : for (i = 1; i<= M ;i++) D =  rowsum(A:*(A*B)) // rowsum(A:*(A*B'))
          
          : 
          : timer_off(3)
          
          : 
          : // rowsum(A:*(A*diag(diagonal(B))))
          : 
          : // AB
          : sum(AB:!=C)
            943
          
          : sum(C:!=D)
            943
          
          : sum(AB:!=D)
            0
          
          : mreldif(AB,C)
            4.80474e-15
          
          : mreldif(AB,D)
            0
          
          : 
          : timer()
          
          -------------------------------------------------------------------------------------------------------
          timer report
            1.       1.95 /        1 =     1.952
            2.        .98 /        1 =       .98
            3.       .985 /        1 =      .985
          -------------------------------------------------------------------------------------------------------
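On the "strangely" part: the three expressions are algebraically identical for any B, symmetric or not, because the double sum behind each diagonal element can be collapsed over either index first. With r indexing the rows of A and k, l its columns:

$$
(ABA')_{rr}
= \sum_{k}\sum_{l} A_{rk}\,B_{kl}\,A_{rl}
= \sum_{k} A_{rk}\,(AB')_{rk}
= \sum_{l} A_{rl}\,(AB)_{rl}
$$

The middle form is rowsum(A:*(A*B')), which equals rowsum(A:*(B*A')') because (B*A')' = A*B'; the last form is rowsum(A:*(A*B)). They differ only in the order in which floating-point products are added, which is one plausible explanation for mreldif(AB,C) being about 5e-15 rather than 0. The exact agreement between AB and D is consistent with Mata evaluating A*B*A' left to right as (A*B)*A', whose diagonal adds the same products A[r,l]*(A*B)[r,l] as rowsum(A:*(A*B)), although the thread does not confirm this.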



          • #6
            Originally posted by Christophe Kolodziejczyk
            I thought that transposing would be a costly operation in that context, but it appears not to be the case.
            ...
            Any idea why it is the case?
            My guess is that this depends on how the matrix is stored internally (row-major or column-major order).

            If the order makes a difference, then colsum() and rowsum() should also run at different speeds:

            Code:
            cls
            clear all
            set more off
            
            
            mata:
                timer_clear()
                x = J(10000,10000,0)
                timer_on(1)
                y = colsum(x)
                timer_off(1)
                timer_on(2)
                z = rowsum(x)
                timer_off(2)
                timer()
            end
            (I can't test it right now because my CPU usage is at 100%, but I suspect it will matter.)

            Interestingly, we also arrived at the same solution as numpy's



            • #7
              Code:
              . mata:
              ------------------------------------------------- mata (type end to exit) -----------------------------------------------------------------------------------------------------------------------------------------------------------
              :     timer_clear()
              
              :     x = J(10000,10000,0)
              
              :     timer_on(1)
              
              :     y = colsum(x)
              
              :     timer_off(1)
              
              :     timer_on(2)
              
              :     z = rowsum(x)
              
              :     timer_off(2)
              
              :     timer()
              
              -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
              timer report
                1.       1.98 /        1 =     1.975
                2.       .145 /        1 =      .145
              -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
              
              : end
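The timings above (colsum() about 14 times slower than rowsum() on the same 10000 x 10000 matrix) are consistent with Mata storing matrices row by row, though the thread does not confirm the internal layout. If that is the reason, one would expect the row-wise reduction to be the better choice for the original problem too. The sketch below lets you check this: diagquad_t() is a hypothetical helper that takes A already transposed (j x i) and reduces with colsum(), so the two timers compare the same computation in the two orientations. The matrix sizes are arbitrary, and timings will vary by machine.

```mata
mata:
// diagquad_t() is a hypothetical helper, not a built-in Mata function.
// It takes At = A' (j x i) and returns diagonal(A*B*A') as an i x 1 vector.
real colvector diagquad_t(real matrix At, real matrix B)
{
    // Element [k,r] of At :* (B*At) is A[r,k] * (B*A')[k,r];
    // summing column r over k gives (A*B*A')[r,r].
    return((colsum(At :* (B*At)))')
}

rseed(504)
A  = runiform(100000, 10)
B  = runiform(10, 10)
At = A'

timer_clear()
timer_on(1)
C1 = rowsum(A :* (A*B'))    // row-wise reduction (the #2 formula)
timer_off(1)
timer_on(2)
C2 = diagquad_t(At, B)      // column-wise reduction on the transposed layout
timer_off(2)
timer()

mreldif(C1, C2)             // should be at rounding-error level
end
```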
