-
LSTMAI/DeepLearning 2022. 10. 4. 17:04
# Input_shape
- 3개의 인자를 받는다 - batch, window_size, time_step(samples, features, time_step)
# LSTM 파라미더 보기
# return_sequences=False 인 경우, 마지막아웃풋만을 반환 x = tf.random.uniform(shape=(32,25, 1)) lstm = LSTM(20) # LSTM에서 output을 20으로 정했따.. output_a = lstm(x) output_a[0] # 한개의 데이터만 보기 => 25개의 widnow size중 하나만을, 한개의 데이터마다 20개의 출력값을 가지고 있따 <tf.Tensor: shape=(20,), dtype=float32, numpy= array([ 0.05721829, 0.07174001, 0.22409686, -0.17256463, 0.00261475, 0.12204304, 0.09205984, -0.08590516, -0.08659009, 0.09036262, -0.02182901, 0.02992326, -0.08432557, 0.04029592, 0.06246312, -0.14243276, -0.01183044, -0.04864559, -0.02895623, 0.09458034], dtype=float32)> output_a[:5] # 5개의 데이터 보기 <tf.Tensor: shape=(5, 20), dtype=float32, numpy= array([[ 0.05721829, 0.07174001, 0.22409686, -0.17256463, 0.00261475, 0.12204304, 0.09205984, -0.08590516, -0.08659009, 0.09036262, -0.02182901, 0.02992326, -0.08432557, 0.04029592, 0.06246312, -0.14243276, -0.01183044, -0.04864559, -0.02895623, 0.09458034], [ 0.07147642, 0.08401469, 0.21837758, -0.16784014, 0.01412732, 0.12846987, 0.0966792 , -0.09791196, -0.106291 , 0.0992037 , -0.01261229, 0.02136466, -0.08759472, 0.05354249, 0.05720092, -0.1265583 , -0.03504875, -0.04306702, -0.01262065, 0.09892861], [ 0.07568941, 0.09950773, 0.24885145, -0.19341308, 0.01248754, 0.12906061, 0.10355023, -0.09850564, -0.11633494, 0.1043708 , -0.02373959, 0.02830943, -0.09332076, 0.05472516, 0.06546076, -0.14443362, -0.03385764, -0.05037692, -0.01409153, 0.10387589], [ 0.0450749 , 0.05851956, 0.1807287 , -0.14339864, -0.00047827, 0.09657934, 0.07322896, -0.06382511, -0.06938691, 0.07246287, -0.01816005, 0.02340475, -0.06645556, 0.03164771, 0.05163307, -0.11315813, -0.00657244, -0.03922507, -0.02633894, 0.07822459], [ 0.08895023, 0.10636114, 0.26272014, -0.20020506, 0.02113218, 0.14875767, 0.11687247, -0.11904698, -0.13261984, 0.11954913, -0.01556885, 0.02586792, -0.10309242, 0.06621417, 0.06660333, -0.14812891, -0.04828766, -0.05037866, -0.0081014 , 0.11517317]], dtype=float32)> output_a.shape # TensorShape([32, 20])
# return_sequences=True 인 경우, 시퀀스 전체를 반환 x = tf.random.uniform(shape=(32, 25, 1)) lstm = LSTM(20, return_sequences=True) output_b = lstm(x) output_b[0] # 하나의 데이터당 25개만의 timestep, 20개의 output 형태 <tf.Tensor: shape=(25, 20), dtype=float32, numpy= array([[ 1.22041907e-03, 1.08073093e-03, 8.16172641e-03, -1.09033659e-03, -7.75709236e-03, -4.45257965e-03, -7.05770263e-03, -3.86817241e-03, -3.60156852e-03, -8.36325064e-03, -4.37901879e-04, 7.19021680e-03, 2.89761391e-03, 3.51775531e-03, 2.54748203e-03, -5.02266968e-03, 6.76784478e-03, -3.13248864e-04, 3.46898660e-03, -6.79486245e-03], [ 6.83384389e-03, 6.60734344e-03, 5.25702275e-02, -6.86181849e-03, -5.19570559e-02, -2.63409279e-02, -4.54780161e-02, -2.48211399e-02, -2.35709939e-02, -5.14949858e-02, -3.18379840e-03, 4.14868407e-02, 1.93437412e-02, 2.13689711e-02, 1.44859375e-02, -3.10363676e-02, 4.05552909e-02, -2.13485747e-03, 2.31895689e-02, -4.34374735e-02], [ 2.64318869e-03, 2.73951166e-03, 3.74350846e-02, -6.85099512e-03, -3.60343494e-02, -1.83704551e-02, -3.30195092e-02, -2.50079501e-02, -1.46113979e-02, -4.49629761e-02, -4.26608138e-03, 4.14540432e-02, 1.02235982e-02, 1.95899662e-02, 8.36186111e-03, -1.82736572e-02, 3.34229842e-02, -3.40385200e-03, 1.63388103e-02, -3.05405334e-02], [ 3.41353775e-03, 3.49241099e-03, 4.85874712e-02, -9.24341008e-03, -4.83330712e-02, -2.43268684e-02, -4.49422523e-02, -3.37673388e-02, -1.95454750e-02, -5.66645823e-02, -5.04006585e-03, 5.57189956e-02, 1.31931957e-02, 2.56509986e-02, 9.92337801e-03, -2.17178166e-02, 4.21627462e-02, -4.28221095e-03, 2.23887283e-02, -4.20755595e-02], [ 6.36718143e-03, 6.20215898e-03, 7.79856592e-02, -1.38687231e-02, -7.92944878e-02, -3.78010683e-02, -7.15809911e-02, -5.11923730e-02, -3.25491540e-02, -8.51761028e-02, -7.20608467e-03, 8.25531557e-02, 2.29664128e-02, 3.86666209e-02, 1.55879809e-02, -3.56691666e-02, 6.36452883e-02, -5.78409480e-03, 3.65474150e-02, -6.83839396e-02], [ 6.58501312e-03, 6.10389281e-03, 9.30031538e-02, -1.76880378e-02, -9.49545130e-02, -4.36609499e-02, -8.51249024e-02, -6.54550344e-02, -3.78353633e-02, -1.04561567e-01, -9.64318216e-03, 1.05324104e-01, 2.56802961e-02, 4.81261536e-02, 1.61721352e-02, -3.94593328e-02, 7.69257694e-02, -7.45421881e-03, 4.36575450e-02, -8.20019692e-02], [ 8.79346672e-03, 8.04074202e-03, 1.24090523e-01, -2.29405742e-02, -1.30217105e-01, -5.65030649e-02, -1.13842063e-01, -8.68583322e-02, -5.18542118e-02, -1.33037329e-01, -1.29528446e-02, 1.33609071e-01, 3.59776504e-02, 6.24329709e-02, 1.98900737e-02, -5.20869605e-02, 9.83205736e-02, -9.26873833e-03, 6.00150786e-02, -1.11215256e-01], [ 4.25276533e-03, 2.31852708e-03, 1.06315054e-01, -2.48167608e-02, -1.08320080e-01, -4.48672809e-02, -9.69268158e-02, -8.95075053e-02, -4.07872610e-02, -1.31802410e-01, -1.41094821e-02, 1.44455329e-01, 2.33815089e-02, 6.11553304e-02, 1.05468221e-02, -3.41583490e-02, 9.23954099e-02, -1.02395341e-02, 4.90785986e-02, -9.28103775e-02], [ 1.94055936e-03, -1.06389925e-03, 9.43412259e-02, -2.55963113e-02, -9.68415514e-02, -3.81637067e-02, -9.04318243e-02, -8.98977518e-02, -3.61867212e-02, -1.23099312e-01, -1.24254311e-02, 1.47078961e-01, 1.67466719e-02, 5.82742356e-02, 3.97967733e-03, -2.05834210e-02, 8.37260559e-02, -9.76885483e-03, 4.48194779e-02, -8.55826661e-02], [ 2.24990025e-03, -2.43173796e-03, 9.51721817e-02, -2.70136166e-02, -9.76843387e-02, -3.69616374e-02, -9.33499932e-02, -9.31358710e-02, -3.76992561e-02, -1.23418570e-01, -1.06840106e-02, 1.53554857e-01, 1.59106795e-02, 5.86968660e-02, 1.85569003e-03, -1.66401602e-02, 8.24358016e-02, -8.92007910e-03, 4.60076369e-02, -8.88233557e-02], [ 1.79997343e-03, -4.39695781e-03, 9.09262523e-02, -2.80475542e-02, -9.17011127e-02, -3.28456350e-02, -9.02425721e-02, -9.36200693e-02, -3.62905599e-02, -1.21647790e-01, -8.89268890e-03, 1.58259869e-01, 1.31055415e-02, 5.74093871e-02, -1.80191756e-03, -1.02467928e-02, 7.87418783e-02, -7.92820193e-03, 4.38204966e-02, -8.60917494e-02], [ 3.90441227e-03, -4.02972940e-03, 1.04116127e-01, -2.98827905e-02, -1.05265543e-01, -3.76197360e-02, -1.02743298e-01, -1.00426473e-01, -4.34223078e-02, -1.30558074e-01, -7.82017037e-03, 1.67508215e-01, 1.78381521e-02, 6.16655946e-02, 2.66354182e-04, -1.55443233e-02, 8.55344981e-02, -7.24374317e-03, 5.08292951e-02, -9.98448282e-02], [ 6.44344231e-03, -2.88807065e-03, 1.25938118e-01, -3.28454711e-02, -1.28291756e-01, -4.60961312e-02, -1.21513069e-01, -1.12526111e-01, -5.38518988e-02, -1.48200184e-01, -8.61966610e-03, 1.82508767e-01, 2.57622711e-02, 6.99679703e-02, 3.88499117e-03, -2.57333070e-02, 9.93457660e-02, -7.35437311e-03, 6.17506355e-02, -1.20283119e-01], [ 5.04594902e-03, -5.23457397e-03, 1.22888982e-01, -3.48064266e-02, -1.22478664e-01, -4.26000096e-02, -1.17248967e-01, -1.16331384e-01, -5.09537831e-02, -1.53150603e-01, -9.23671294e-03, 1.93253055e-01, 2.21021939e-02, 7.12714195e-02, -6.56124728e-04, -1.94427576e-02, 1.00354224e-01, -7.42991781e-03, 5.88355660e-02, -1.16296351e-01], [ 5.41581027e-03, -6.00686157e-03, 1.29261807e-01, -3.63603905e-02, -1.29939824e-01, -4.47281525e-02, -1.24537595e-01, -1.22348957e-01, -5.44689633e-02, -1.57789409e-01, -9.09790862e-03, 2.00326994e-01, 2.37208977e-02, 7.42398277e-02, -1.18806632e-03, -1.97740663e-02, 1.04066551e-01, -7.31136557e-03, 6.31471649e-02, -1.24459215e-01], [-2.57400126e-04, -1.18501959e-02, 9.26809236e-02, -3.57057527e-02, -8.68967846e-02, -2.39204224e-02, -9.13387313e-02, -1.09331898e-01, -3.67650315e-02, -1.37304977e-01, -6.73473580e-03, 1.93733320e-01, 7.66025251e-03, 6.16027787e-02, -1.58549082e-02, 6.59017824e-03, 8.15925971e-02, -5.94445458e-03, 4.33759093e-02, -8.85856822e-02], [ 3.23198154e-03, -1.07235806e-02, 1.07802868e-01, -3.55938226e-02, -1.04435399e-01, -3.14487591e-02, -1.08411603e-01, -1.13045387e-01, -4.65876795e-02, -1.39135063e-01, -3.95943038e-03, 1.91951707e-01, 1.47466324e-02, 6.41621947e-02, -9.56800859e-03, -3.27272550e-03, 8.63615349e-02, -4.64590034e-03, 5.36931232e-02, -1.07374951e-01], [ 3.31505388e-03, -1.13566061e-02, 1.07053772e-01, -3.65490206e-02, -1.00939609e-01, -3.00666299e-02, -1.06285036e-01, -1.13345139e-01, -4.57491502e-02, -1.41881958e-01, -3.19361081e-03, 1.97307259e-01, 1.39580239e-02, 6.40551895e-02, -1.14321960e-02, -1.11951272e-03, 8.65841210e-02, -3.91329359e-03, 5.23650125e-02, -1.06184982e-01], [ 4.35030647e-03, -1.11768991e-02, 1.13360159e-01, -3.71991880e-02, -1.07145652e-01, -3.28061096e-02, -1.12346187e-01, -1.16103351e-01, -4.90793623e-02, -1.45463750e-01, -2.42656516e-03, 2.00686812e-01, 1.64688751e-02, 6.59547225e-02, -1.01681706e-02, -4.00207238e-03, 8.99741054e-02, -3.50161595e-03, 5.61498739e-02, -1.13541752e-01], [ 1.45510782e-03, -1.39279040e-02, 9.32101160e-02, -3.66853625e-02, -8.30817446e-02, -2.18354855e-02, -9.38606709e-02, -1.08023494e-01, -3.91366631e-02, -1.33681491e-01, -9.07015288e-04, 1.97088316e-01, 8.12417269e-03, 5.87929673e-02, -1.83447339e-02, 9.99613944e-03, 7.73777291e-02, -2.50909873e-03, 4.54900190e-02, -9.44078267e-02], [-3.19188926e-04, -1.59332100e-02, 7.45606720e-02, -3.46226990e-02, -6.25108033e-02, -1.22703481e-02, -7.90117010e-02, -9.71155986e-02, -3.16912010e-02, -1.15297124e-01, 2.24257121e-03, 1.84684724e-01, 2.12084292e-03, 4.98964079e-02, -2.39737760e-02, 2.05928925e-02, 6.19066618e-02, -7.51930580e-04, 3.76420207e-02, -7.95149654e-02], [ 5.40334312e-03, -1.22013399e-02, 1.06692798e-01, -3.52777764e-02, -9.70118642e-02, -2.77729463e-02, -1.07845083e-01, -1.06429689e-01, -4.84065749e-02, -1.32182404e-01, 3.17935040e-03, 1.88173130e-01, 1.61816832e-02, 5.84569201e-02, -1.09731378e-02, -1.95135316e-03, 7.96920732e-02, -5.03240270e-04, 5.46936914e-02, -1.10440038e-01], [ 1.72154058e-03, -1.47834513e-02, 8.36473703e-02, -3.50130908e-02, -6.87270388e-02, -1.55026997e-02, -8.40454176e-02, -9.74724144e-02, -3.54899466e-02, -1.22417346e-01, 3.34686390e-03, 1.86520100e-01, 6.38154894e-03, 5.11072427e-02, -2.08550021e-02, 1.36252707e-02, 6.71976060e-02, 1.38333809e-04, 4.09751609e-02, -8.59601945e-02], [-2.63828173e-04, -1.64522398e-02, 6.48955256e-02, -3.24326009e-02, -4.90309261e-02, -6.73476839e-03, -6.93339407e-02, -8.61879066e-02, -2.78159827e-02, -1.02744557e-01, 5.88867487e-03, 1.71673462e-01, 6.40149927e-04, 4.23096158e-02, -2.59049498e-02, 2.33084094e-02, 5.17392345e-02, 1.62587618e-03, 3.35025750e-02, -7.13706464e-02], [-4.76299669e-04, -1.67267099e-02, 5.40735275e-02, -3.00311949e-02, -3.70791033e-02, -1.70586980e-03, -6.01031967e-02, -7.65181333e-02, -2.37342175e-02, -8.88250843e-02, 8.29489622e-03, 1.58588126e-01, -1.57908269e-03, 3.56241241e-02, -2.75230967e-02, 2.68288683e-02, 4.12583649e-02, 3.22037865e-03, 2.91442778e-02, -6.29316196e-02]], dtype=float32)> output_b.shape # 32는 batch의 갯수,25는 timestep, 20은 lstm output TensorShape([32, 25, 20])
# return_sequence=False 를 Dense에 넘겨줄 경우 dense = Dense(10) # (32, 20) => (32,10)으로 변함 dense(output_a).shape TensorShape([32, 10]) # return_sequence=True 를 Dense에 넘겨줄 경우 dense = Dense(10) # (32, 25, 20) => (32,25,10)으로 변환 dense(output_b).shape TensorShape([32, 25, 10])
LSTM의 입력으로는 3차원의 데이터가 필요(samples, time_steps, features)
1. sampels => 데이터의 크기 2. times steps => 과거 몇 개의 데이터를 볼 것인가를 나타내며, 네트워크에 사용할 시간 단위 3. X의 변수 갯(특성갯수)
LSTM 입출력 구조는 다양하다
- 일대다(One-to-Many), 다대일(Many-to-One), 다대다(Many-to-Many)
다대일 다대다 1) MANY TO ONE
- Feature= 5개는 5가지의 컬럼을 뜻한다.(MANY)
- 현재와 과거 2일 동안의 가격을 사용하여 미래 1일 동안(TO ONE)의 가격을 예측. Time steps는 현재+ 과거2일 해서 3인 것이다
- Sample은 5가 된다. 왜냐하면 Feature X Time Step // Time Step 이기 때문
2) MANY TO MANY
- 과거 X일 동안의 가격을 사용해서 미래 Y일 동안 가격을 예측한다
입력 데이터 정규화하기
- LSTM과 GRU는 입력 데이터들의 단위가 다른데, 정규화를 하지 않으면 모델이 데이터를 학습하는데 좋지 않을 수 있다.
- 그래서 입력 데이터를 3차원으로 분리하기 전 정규화 과정을 거치는 것이 일반적
- 학습데이터만 스케일 변환에 사용한다는 것
LSTM 모델을 구축하는데 필요한 요소 3가지
- 아키텍처, 컴파일링, 피팅
1) 아키텍처
- LSTM의 모델을 어떤 구조로 쌓을 것인지 정하는 부분
model = Sequential() model.add(LSTM()).. model.add(Dense())..
2) 컴파일링
- 모델을 학습시키기전에 모델 학습 환경에 대한 설정을 해 주는 부분
optimizer=.. 정규화 방법을 지정하는 부분 loss=... 모델을 최적화 시키는데 사용되는 목적함수(=손실함수)를 지정하는 부분 metric=.. 분류 문제에서 어떤 것을 기준으로 삼을지 정하는 부분
3) 피팅
- 모델을 학습시키는 부분
epoch.. batch_size.. iteration..
참고:
https://abstractask.tistory.com/105
앙상블
concatenate 2개 이상의 모델 합치기 데이터 구성 x 값 y값이 미리 유추되지 않도록 데이터 컬럼을 약간 섞음 x 2개 : 300개씩의 데이터 y 1개 : 100개의 데이터 #1. 데이터 import numpy as np x1 = np.array([ra..
abstractask.tistory.com
https://data-analysis-expertise.tistory.com/67
[LSTM/GRU] 주식가격 예측 모델 구현하기
LSTM과 GRU 코드 예시를 찾아 공부하던 중 설명이 가장 자세했던 글을 소개합니다. 설명이 자세해서 ARIMA, RNN, LSTM, GRU 의 차이와 특징을 잘 알 수 있었습니다. 그 중 LSTM과 GRU 를 제가 이해한 바와 함
data-analysis-expertise.tistory.com
https://tykimos.github.io/2017/08/17/Text_Input_Multiclass_Classification_Model_Recipe/
'AI > DeepLearning' 카테고리의 다른 글
어테션메커니즘(AttentionMechanism) (0) 2022.10.07 콜백함수(EarlyStopping) (0) 2022.10.05 인공신경망(Artificial Neural Network) (0) 2022.10.03 레이어들(Layers) (0) 2022.09.30 퍼셉트론(Perceptron) (0) 2022.09.29