LSTM
AI/DeepLearning 2022. 10. 4. 17:04

     

    # Input_shape

    - The LSTM input takes three dimensions
    - (samples, time_steps, features): the batch size, the window size (number of time steps), and the number of features

    # Inspecting the LSTM parameters

    # With return_sequences=False, only the output of the last time step is returned

    import tensorflow as tf
    from tensorflow.keras.layers import LSTM, Dense

    x = tf.random.uniform(shape=(32, 25, 1))
    lstm = LSTM(20)  # the LSTM output dimension is set to 20
    output_a = lstm(x)
    output_a[0]  # view one sample: only the last of the 25 time steps remains, with 20 output values per sample
    
    <tf.Tensor: shape=(20,), dtype=float32, numpy=
    array([ 0.05721829,  0.07174001,  0.22409686, -0.17256463,  0.00261475,
            0.12204304,  0.09205984, -0.08590516, -0.08659009,  0.09036262,
           -0.02182901,  0.02992326, -0.08432557,  0.04029592,  0.06246312,
           -0.14243276, -0.01183044, -0.04864559, -0.02895623,  0.09458034],
          dtype=float32)>
          
          
    output_a[:5]  # view the first 5 samples
    
    <tf.Tensor: shape=(5, 20), dtype=float32, numpy=
    array([[ 0.05721829,  0.07174001,  0.22409686, -0.17256463,  0.00261475,
             0.12204304,  0.09205984, -0.08590516, -0.08659009,  0.09036262,
            -0.02182901,  0.02992326, -0.08432557,  0.04029592,  0.06246312,
            -0.14243276, -0.01183044, -0.04864559, -0.02895623,  0.09458034],
           [ 0.07147642,  0.08401469,  0.21837758, -0.16784014,  0.01412732,
             0.12846987,  0.0966792 , -0.09791196, -0.106291  ,  0.0992037 ,
            -0.01261229,  0.02136466, -0.08759472,  0.05354249,  0.05720092,
            -0.1265583 , -0.03504875, -0.04306702, -0.01262065,  0.09892861],
           [ 0.07568941,  0.09950773,  0.24885145, -0.19341308,  0.01248754,
             0.12906061,  0.10355023, -0.09850564, -0.11633494,  0.1043708 ,
            -0.02373959,  0.02830943, -0.09332076,  0.05472516,  0.06546076,
            -0.14443362, -0.03385764, -0.05037692, -0.01409153,  0.10387589],
           [ 0.0450749 ,  0.05851956,  0.1807287 , -0.14339864, -0.00047827,
             0.09657934,  0.07322896, -0.06382511, -0.06938691,  0.07246287,
            -0.01816005,  0.02340475, -0.06645556,  0.03164771,  0.05163307,
            -0.11315813, -0.00657244, -0.03922507, -0.02633894,  0.07822459],
           [ 0.08895023,  0.10636114,  0.26272014, -0.20020506,  0.02113218,
             0.14875767,  0.11687247, -0.11904698, -0.13261984,  0.11954913,
            -0.01556885,  0.02586792, -0.10309242,  0.06621417,  0.06660333,
            -0.14812891, -0.04828766, -0.05037866, -0.0081014 ,  0.11517317]],
          dtype=float32)>
          
          
    output_a.shape
    # TensorShape([32, 20])

     

    # With return_sequences=True, the entire sequence is returned
    
    x = tf.random.uniform(shape=(32, 25, 1))
    lstm = LSTM(20, return_sequences=True)
    output_b = lstm(x)
    output_b[0]  # one sample: 25 time steps, each with 20 output values
    
    <tf.Tensor: shape=(25, 20), dtype=float32, numpy=
    array([[ 1.22041907e-03,  1.08073093e-03,  8.16172641e-03, ...,
            -3.13248864e-04,  3.46898660e-03, -6.79486245e-03],
           ...,
           [-4.76299669e-04, -1.67267099e-02,  5.40735275e-02, ...,
             3.22037865e-03,  2.91442778e-02, -6.29316196e-02]],
          dtype=float32)>
             
             
    output_b.shape  # 32 is the batch size, 25 the number of time steps, 20 the LSTM output dimension
    TensorShape([32, 25, 20])

    # Passing the return_sequences=False output to a Dense layer
    
    dense = Dense(10)  # (32, 20) => (32, 10)
    dense(output_a).shape
    
    TensorShape([32, 10])
    
    
    # Passing the return_sequences=True output to a Dense layer
    
    dense = Dense(10)  # (32, 25, 20) => (32, 25, 10)
    dense(output_b).shape
    
    TensorShape([32, 25, 10])
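Putting the two cases together: when stacking LSTMs, the intermediate layers need return_sequences=True so the next LSTM receives the full sequence, while the last LSTM returns only the final step before the Dense layer. A minimal sketch (the layer sizes are illustrative, matching the shapes above):

```python
import tensorflow as tf
from tensorflow.keras.layers import LSTM, Dense

x = tf.random.uniform(shape=(32, 25, 1))  # (batch, time_steps, features)

# Intermediate layer keeps the whole sequence for the next LSTM
seq = LSTM(20, return_sequences=True)(x)   # -> (32, 25, 20)
# Final layer returns only the last time step
last = LSTM(20)(seq)                       # -> (32, 20)
pred = Dense(1)(last)                      # -> (32, 1), one prediction per sample
```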


    The LSTM input must be 3-dimensional: (samples, time_steps, features)

    1. samples => the size of the dataset (number of data points)
    2. time_steps => how many past observations to look at; the time window the network uses
    3. features => the number of input variables (columns) of X

    LSTM supports a variety of input/output structures

    - One-to-Many, Many-to-One, Many-to-Many

     


    1) MANY TO ONE

    - Features = 5 means five input columns (the MANY side).

    - Use today's price plus the prices of the past 2 days to predict the price 1 day ahead (TO ONE), so time_steps is 3 (today + 2 past days).

    - The number of samples is the number of length-3 windows that can be sliced from the series; in this example that works out to 5.
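As a sketch of the windowing described above: `make_windows` is a hypothetical helper (not from the original post) that slices a (rows, features) table into many-to-one samples:

```python
import numpy as np

def make_windows(series, time_steps):
    """Slice a (rows, features) array into (samples, time_steps, features)
    inputs plus the next-step target of the first column for each window."""
    X, y = [], []
    for i in range(len(series) - time_steps):
        X.append(series[i:i + time_steps])   # window of `time_steps` rows
        y.append(series[i + time_steps, 0])  # value one step after the window
    return np.array(X), np.array(y)

data = np.arange(35, dtype="float32").reshape(7, 5)  # 7 days x 5 features
X, y = make_windows(data, time_steps=3)
# X.shape == (4, 3, 5): 4 samples, 3 time steps, 5 features
```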

     

     

    2) MANY TO MANY

    - Use the prices of the past X days to predict the prices of the next Y days.

     

     

    Normalizing the input data

    - The input features of an LSTM or GRU often have different units and scales; without normalization the model may have trouble learning from the data.

    - It is therefore common to normalize the input before reshaping it into 3D.

    - Only the training data should be used to fit the scaler.
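A minimal min-max scaling sketch under stated assumptions (the column scales below are made up; this mirrors what scikit-learn's MinMaxScaler does), fitting the scaling parameters on the training split only:

```python
import numpy as np

rng = np.random.default_rng(0)
# Features on very different scales (illustrative data)
data = rng.random((100, 5)) * np.array([1, 10, 100, 1000, 10000])

train, test = data[:80], data[80:]

# Fit min/max on the training split only...
lo, hi = train.min(axis=0), train.max(axis=0)
train_scaled = (train - lo) / (hi - lo)
# ...and apply the SAME parameters to the test split
test_scaled = (test - lo) / (hi - lo)  # may fall slightly outside [0, 1]
```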

     

    Three things are needed to build an LSTM model

    - architecture, compiling, fitting

     

    1) Architecture

    - Decide how the layers of the LSTM model are stacked

    model = Sequential()
    model.add(LSTM(20, input_shape=(25, 1)))  # layer sizes here are illustrative
    model.add(Dense(1))

     

    2) Compiling

    - Configure the training setup before the model is trained

    optimizer=...  specifies the optimization algorithm
    loss=...       specifies the objective (loss) function used to optimize the model
    metrics=...    specifies which metrics to track, e.g. accuracy for classification problems

     

    3) Fitting

    - Train the model

    epochs...      number of full passes over the training data
    batch_size...  number of samples per gradient update
    iterations...  number of batches processed per epoch
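The three steps above can be sketched end to end; the data and hyperparameter values here are dummies for illustration, not from the original post:

```python
import numpy as np
from tensorflow.keras import Sequential
from tensorflow.keras.layers import LSTM, Dense

# Dummy data: 100 samples, 25 time steps, 1 feature
X = np.random.rand(100, 25, 1).astype("float32")
y = np.random.rand(100, 1).astype("float32")

# 1) architecture
model = Sequential([
    LSTM(20, input_shape=(25, 1)),
    Dense(1),
])

# 2) compiling
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

# 3) fitting: 2 epochs, 16 samples per gradient update
history = model.fit(X, y, epochs=2, batch_size=16, verbose=0)
```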

     

    References:

    https://abstractask.tistory.com/105

     


    https://data-analysis-expertise.tistory.com/67

     


     

     

    https://tykimos.github.io/2017/08/17/Text_Input_Multiclass_Classification_Model_Recipe/

    https://zereight.tistory.com/227
