• Let's do distributed training and prediction with Google Cloud Platform

    2019. 2. 22. 19:47

    by. 위지원

    I followed the guide below.



    This example code performs data-parallel training with asynchronous updates on Cloud ML Engine.


    A neural network model is a function that can be learned from examples.


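    The whole model is shared with every worker node. Each node takes its own portion of the training dataset and, just like mini-batch processing, independently computes a gradient vector on it. These gradients are collected by the parameter server, which uses them to update the shared model parameters.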
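    To make the parameter-server idea above concrete, here is a minimal pure-Python sketch of asynchronous data-parallel SGD. It is not the example's TensorFlow code: the toy linear model `y = w*x + b`, the learning rate, the data shards, and the use of threads as stand-in worker nodes are all my own assumptions.

```python
import random
import threading


class ParameterServer:
    """Holds the shared parameters and applies gradients as they arrive."""

    def __init__(self, params, lr):
        self._params = list(params)
        self._lr = lr
        self._lock = threading.Lock()

    def pull(self):
        # Workers fetch a snapshot of the current parameters.
        with self._lock:
            return list(self._params)

    def push(self, grads):
        # Asynchronous update: apply this worker's gradient immediately,
        # without waiting for the other workers.
        with self._lock:
            self._params = [p - self._lr * g for p, g in zip(self._params, grads)]


def gradient(params, batch):
    """Gradient of mean squared error for the toy model y = w * x + b."""
    w, b = params
    gw = gb = 0.0
    for x, y in batch:
        err = (w * x + b) - y
        gw += 2 * err * x
        gb += 2 * err
    n = len(batch)
    return [gw / n, gb / n]


def worker(ps, shard, steps, batch_size, seed):
    # Illustrative only: a thread stands in for a worker node (assumption).
    rng = random.Random(seed)
    for _ in range(steps):
        params = ps.pull()                     # possibly stale parameters
        batch = rng.sample(shard, batch_size)  # mini-batch from this worker's shard
        ps.push(gradient(params, batch))


def train(num_workers=3, steps=300, batch_size=8, lr=0.05):
    rng = random.Random(0)
    data = [(x, 2.0 * x + 1.0) for x in (rng.uniform(-1, 1) for _ in range(300))]
    # Each worker sees only its own slice of the training data.
    shards = [data[i::num_workers] for i in range(num_workers)]
    ps = ParameterServer([0.0, 0.0], lr)
    threads = [
        threading.Thread(target=worker, args=(ps, shards[i], steps, batch_size, i))
        for i in range(num_workers)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return ps.pull()


if __name__ == "__main__":
    w, b = train()
    print(f"w={w:.3f}, b={b:.3f}")  # should approach w=2, b=1
```

    Because no worker waits for the others, a worker may compute its gradient on parameters that another worker has already changed. With a small learning rate this staleness is tolerable, which is the trade-off asynchronous updates make for throughput.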


    As shown in the figure below, we run a distributed training job on ML Engine and run predictions with Datalab. We'll even visualize things with TensorBoard.

    Figure: Architecture used in the tutorial.


    The first part just sets up the environment and downloads the data, so it's easy. Once that setup is done, submit a training request to ML Engine.


    upsidejiwon@cloudshell:~/cloudml-dist-mnist-example/cloudml-dist-mnist-example/cloudml-dist-mnist-example (tensorproject)$ JOB_NAME="job_$(date +%Y%m%d_%H%M%S)"
    upsidejiwon@cloudshell:~/cloudml-dist-mnist-example/cloudml-dist-mnist-example/cloudml-dist-mnist-example (tensorproject)$ gcloud ml-engine jobs submit training ${JOB_NAME} \
    >     --package-path trainer \
    >     --module-name trainer.task \
    >     --staging-bucket gs://${BUCKET} \
    >     --job-dir gs://${BUCKET}/${JOB_NAME} \
    >     --runtime-version 1.2 \
    >     --region us-central1 \
    >     --config config/config.yaml \
    >     -- \
    >     --data_dir gs://${BUCKET}/data \
    >     --output_dir gs://${BUCKET}/${JOB_NAME} \
    >     --train_steps 10000
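    Everything after the bare `--` in the submit command is forwarded to `trainer.task` as its own command-line arguments. Here is a minimal sketch of how a trainer could read those three flags with `argparse`; the parser, its defaults, and its help strings are hypothetical, not the actual code in the example's `trainer` package.

```python
import argparse


def parse_args(argv=None):
    """Parse the user arguments that gcloud forwards after the bare `--`.

    The flag names mirror the submit command; everything else here is an
    illustrative assumption.
    """
    parser = argparse.ArgumentParser(description="MNIST trainer (sketch)")
    parser.add_argument("--data_dir", required=True,
                        help="GCS path holding the MNIST data")
    parser.add_argument("--output_dir", required=True,
                        help="GCS path for checkpoints and exported models")
    parser.add_argument("--train_steps", type=int, default=10000,
                        help="number of training steps")
    # parse_known_args tolerates extra flags this sketch does not declare.
    args, _unknown = parser.parse_known_args(argv)
    return args


if __name__ == "__main__":
    # Example argv mirroring the submit command (bucket name is hypothetical).
    print(parse_args(["--data_dir", "gs://my-bucket/data",
                      "--output_dir", "gs://my-bucket/job_x",
                      "--train_steps", "10000"]))
```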


    You can then watch the training progress in the Cloud ML Engine Jobs page.



    Looking at the logs, you can see the loss being computed. You can also see the master and the workers working together, right?
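    As far as I understand, ML Engine tells each replica whether it is the master, a worker, or a parameter server through a JSON `TF_CONFIG` environment variable with `cluster` and `task` entries (the cluster size presumably comes from `config/config.yaml`). Here is a small sketch that reads such a value; the helper name, host names, and cluster size are hypothetical.

```python
import json
import os


def describe_task(tf_config_json):
    """Return (role, index, cluster_sizes) from a TF_CONFIG-style JSON string."""
    cfg = json.loads(tf_config_json)
    task = cfg.get("task", {})
    # Assumption: treat a missing task entry as a single-node master.
    role = task.get("type", "master")
    index = int(task.get("index", 0))
    sizes = {job: len(addrs) for job, addrs in cfg.get("cluster", {}).items()}
    return role, index, sizes


if __name__ == "__main__":
    # Hypothetical value; on a real node it would come from the environment.
    example = json.dumps({
        "cluster": {
            "master": ["master-0:2222"],
            "worker": ["worker-0:2222", "worker-1:2222"],
            "ps": ["ps-0:2222"],
        },
        "task": {"type": "worker", "index": 1},
    })
    role, index, sizes = describe_task(os.environ.get("TF_CONFIG", example))
    print(role, index, sizes)
```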



    All the training is done... the accuracy is almost 99%.




    The model has been created... everything is going smoothly. Nice.




    upsidejiwon@cloudshell:~/cloudml-dist-mnist-example/cloudml-dist-mnist-example/cloudml-dist-mnist-example (tensorproject)$ gsutil ls gs://${BUCKET}/${JOB_NAME}/export/Servo | tail -1
    gs://tensorproject-ml/job_20190222_191727/export/Servo/1550831466/
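    `tail -1` works here because, as far as I know, GCS listings come back in lexicographic order, and these export directories are named with same-length Unix timestamps, so the last line is the newest export. Here is the same choice in Python, comparing the timestamps numerically instead; the helper name is hypothetical.

```python
def latest_export(paths):
    """Pick the export directory whose last path component is the largest timestamp.

    Raises ValueError if no path ends in a numeric directory name.
    """
    def last_component(path):
        return path.rstrip("/").rsplit("/", 1)[-1]

    # Skip anything that is not a timestamped export directory.
    candidates = [p for p in paths if last_component(p).isdigit()]
    return max(candidates, key=lambda p: int(last_component(p)))


if __name__ == "__main__":
    listing = [
        "gs://tensorproject-ml/job_20190222_191727/export/Servo/1550831466/",
        "gs://tensorproject-ml/job_20190222_191727/export/Servo/1550800000/",  # hypothetical older export
    ]
    print(latest_export(listing))
```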


    Let's visualize the results with the command below and take a look.

    tensorboard --port 8080 --logdir gs://${BUCKET}/${JOB_NAME}

    You can see the loss gradually going down and the accuracy going up.



    I don't know exactly what all of it means, but you can also inspect the TensorFlow graph ^^



    Following the guide, let's create a deployment version of the model and run predictions.


    MODEL_NAME=MNIST
    gcloud ml-engine models create --regions us-central1 ${MODEL_NAME}
    VERSION_NAME=v1
    ORIGIN=$(gsutil ls gs://${BUCKET}/${JOB_NAME}/export/Servo | tail -1)
    gcloud ml-engine versions create \
        --origin ${ORIGIN} \
        --model ${MODEL_NAME} \
        ${VERSION_NAME}
    gcloud ml-engine versions set-default --model ${MODEL_NAME} ${VERSION_NAME}

    upsidejiwon@cloudshell:~/cloudml-dist-mnist-example/cloudml-dist-mnist-example/cloudml-dist-mnist-example (tensorproject)$ gcloud ml-engine versions set-default --model ${MODEL_NAME} ${VERSION_NAME}
    createTime: '2019-02-22T10:37:56Z'
    deploymentUri: gs://tensorproject-ml/job_20190222_191727/export/Servo/1550831466/
    etag: YTZ1Vdw41UI=
    framework: TENSORFLOW
    isDefault: true
    machineType: mls1-c1-m2
    name: projects/tensorproject/models/MNIST/versions/v1
    pythonVersion: '2.7'
    runtimeVersion: '1.0'
    state: READY
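    The version's `name` above (`projects/tensorproject/models/MNIST/versions/v1`) is also the resource path used by the online prediction REST API, which `gcloud ml-engine predict` talks to. Here is a sketch that only builds the request URL and JSON body; actually sending it needs an OAuth access token, which is left out, and the `image` instance field is my assumption about the model's serving input.

```python
import json


def predict_request(project, model, instances, version=None):
    """Build (url, body) for an online prediction call to a deployed model.

    Sending the request also needs an OAuth access token, deliberately
    omitted here.
    """
    name = f"projects/{project}/models/{model}"
    if version:
        name += f"/versions/{version}"  # otherwise the default version is used
    url = f"https://ml.googleapis.com/v1/{name}:predict"
    body = json.dumps({"instances": instances})
    return url, body


if __name__ == "__main__":
    # Hypothetical all-zero image; the "image" key is an assumption.
    url, body = predict_request("tensorproject", "MNIST", [{"image": [0.0] * 784}])
    print(url)
```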


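    Then run the two commands below. The first downloads the MNIST data and writes request.json, and the second sends it to the deployed model, which returns its predictions!

    ./scripts/make_request.py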
    gcloud ml-engine predict --model ${MODEL_NAME} --json-instances request.json
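    `--json-instances` reads a newline-delimited JSON file with one instance per line. Here is a sketch of producing such a file; the `image` key and the flat 784-value pixel list are my assumptions and not necessarily what `make_request.py` writes.

```python
import json


def json_instances(images):
    """Serialize images as newline-delimited JSON, one instance per line.

    The "image" key and flat 784-value layout are assumptions about the
    model's serving input.
    """
    return "".join(
        json.dumps({"image": [float(p) for p in pixels]}) + "\n" for pixels in images
    )


def write_json_instances(path, images):
    with open(path, "w") as f:
        f.write(json_instances(images))


if __name__ == "__main__":
    # Hypothetical all-black 28x28 images stand in for real MNIST test data.
    write_json_instances("request.json", [[0.0] * 784 for _ in range(10)])
```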

    upsidejiwon@cloudshell:~/cloudml-dist-mnist-example/cloudml-dist-mnist-example/cloudml-dist-mnist-example (tensorproject)$ ./scripts/make_request.py
    WARNING:tensorflow:From ./scripts/make_request.py:20: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
    Instructions for updating:
    Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
    WARNING:tensorflow:From /usr/local/lib/python2.7/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
    Instructions for updating:
    Please write your own downloading logic.
    WARNING:tensorflow:From /usr/local/lib/python2.7/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
    Instructions for updating:
    Please use tf.data to implement this functionality.
    Extracting /tmp/data/train-images-idx3-ubyte.gz
    WARNING:tensorflow:From /usr/local/lib/python2.7/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
    Instructions for updating:
    Please use tf.data to implement this functionality.
    Extracting /tmp/data/train-labels-idx1-ubyte.gz
    WARNING:tensorflow:From /usr/local/lib/python2.7/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
    Instructions for updating:
    Please use tf.one_hot on tensors.
    Extracting /tmp/data/t10k-images-idx3-ubyte.gz
    Extracting /tmp/data/t10k-labels-idx1-ubyte.gz
    WARNING:tensorflow:From /usr/local/lib/python2.7/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: __init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
    Instructions for updating:
    Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
    upsidejiwon@cloudshell:~/cloudml-dist-mnist-example/cloudml-dist-mnist-example/cloudml-dist-mnist-example (tensorproject)$ gcloud ml-engine predict --model ${MODEL_NAME} --json-instances request.json
    CLASSES  PROBABILITIES
    7        [2.252253195210896e-16, 2.62512125752401e-13, 8.145778634407151e-15, 1.587594191851914e-14, 8.453198400591375e-14, 2.0736690037765357e-17, 3.7328870973973786e-20, 1.0, 4.481962608530619e-15, 1.2157129144520218e-13]
    2        [1.4830203037830003e-12, 1.703580954418287e-14, 1.0, 1.1505197786656809e-21, 8.483649065638577e-17, 1.5231707229913098e-25, 2.2731596624150268e-15, 1.065206680576563e-20, 1.5161558883606955e-18, 5.817562282061209e-19]
    1        [3.5372128403404757e-13, 1.0, 4.405297063551569e-14, 2.3924037960591847e-16, 8.533004164368307e-12, 1.1116443052530678e-12, 1.5194136638967126e-13, 1.8340590954594657e-14, 1.1466857324782254e-10, 3.1506440793978296e-13]
    0        [1.0, 4.7600195540817986e-14, 5.79419216599486e-13, 2.4598521190886263e-14, 4.2716803266276346e-15, 3.350107802388598e-14, 1.1670789890061428e-09, 1.1326154597535165e-10, 7.379767912214785e-13, 5.6593601333032595e-11]
    4        [2.899808756832989e-13, 1.2658522008379691e-09, 1.9342487680984455e-12, 7.295591336457276e-15, 0.9999998807907104, 6.389001827589169e-11, 2.947848909418127e-11, 8.994636979675619e-11, 2.6609903613916686e-08, 7.596319306912847e-08]
    1        [1.1502748623395953e-12, 1.0, 7.058507830960334e-14, 1.3835189991501446e-16, 3.58129532973539e-11, 5.926553735616236e-13, 2.285072901297286e-13, 1.3154395921057688e-13, 6.75269284933222e-10, 2.9094391826300914e-12]
    4        [6.293827146505168e-22, 3.5090891303823923e-10, 7.165834213683602e-18, 8.867360010697118e-22, 1.0, 4.282644744613368e-15, 1.3543444051198483e-16, 1.7946323958004296e-16, 1.922331627213225e-09, 7.735121137358858e-14]
    9        [6.9990743242602255e-15, 1.5780783624297356e-10, 1.3390492707709978e-11, 2.861841080954719e-13, 2.964413425843304e-08, 9.113979102665093e-13, 3.8855367995179163e-16, 1.2258579297052402e-13, 1.3869858150883374e-08, 1.0]
    5        [1.5480527508771047e-06, 4.7186694018819253e-07, 7.422308878624051e-10, 1.8296990589306006e-09, 1.5254257959895767e-05, 0.7882888913154602, 0.21087637543678284, 5.837003658193396e-10, 0.0007937122718431056, 2.3742037228657864e-05]
    9        [8.550430837266001e-15, 3.3524023401038106e-12, 5.6722454932643575e-15, 4.3564240756466255e-12, 1.2762179721903522e-06, 1.9413998186834647e-11, 4.2428393841692177e-16, 5.374457032303326e-06, 7.872608165371275e-08, 0.9999932050704956]
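    In every row above, CLASSES is the index of the largest value in PROBABILITIES, and most rows are nearly one-hot. The ninth row is the interesting one: about 0.79 for 5 versus 0.21 for 6. Here is a quick sketch that ranks the two most likely digits, using two rows copied from the output above (the helper name is hypothetical).

```python
# Rows copied verbatim from the prediction output above.
PROBS_ROW9 = [1.5480527508771047e-06, 4.7186694018819253e-07, 7.422308878624051e-10,
              1.8296990589306006e-09, 1.5254257959895767e-05, 0.7882888913154602,
              0.21087637543678284, 5.837003658193396e-10, 0.0007937122718431056,
              2.3742037228657864e-05]
PROBS_ROW5 = [2.899808756832989e-13, 1.2658522008379691e-09, 1.9342487680984455e-12,
              7.295591336457276e-15, 0.9999998807907104, 6.389001827589169e-11,
              2.947848909418127e-11, 8.994636979675619e-11, 2.6609903613916686e-08,
              7.596319306912847e-08]


def top_two(probs):
    """Return [(digit, prob), (digit, prob)] for the two most likely digits."""
    ranked = sorted(enumerate(probs), key=lambda kv: kv[1], reverse=True)
    return ranked[:2]


if __name__ == "__main__":
    for digit, p in top_two(PROBS_ROW9):
        print(f"{digit}: {p:.3f}")
    # 5: 0.788
    # 6: 0.211
```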