#!/bin/bash

# Commands:
# connect <kernel_gateway_name> [--ssh-only]
# run-command <command> <args...>

# The SageMaker Studio Kernel Gateway name is usually the same as the hostname,
# e.g. sagemaker-data-science-ml-m5-large-1234567890abcdef0

# First positional argument selects the sub-command; anything other than
# "connect" or "run-command" is handled by the deprecated fallback branch.
COMMAND=$1

if [[ "$COMMAND" == "connect" ]]; then

  SM_STUDIO_KGW_NAME="$2"
  OPTIONS="$3"

  # Fail fast instead of querying for an empty kernel gateway name.
  if [[ -z "$SM_STUDIO_KGW_NAME" ]]; then
    echo "sm-local-ssh-ide: missing <kernel_gateway_name>" >&2
    echo "Usage: sm-local-ssh-ide connect <kernel_gateway_name> [--ssh-only]" >&2
    exit 2
  fi

  # Resolve the kernel gateway name to an instance ID. The name is handed to
  # Python as an argument (sys.argv[1]) and the heredoc delimiter is quoted,
  # so the value is never interpolated into the Python source code.
  # Only the instance ID is expected on stdout; logging output of
  # logging.basicConfig() goes to stderr by default.
  INSTANCE_ID=$(python - "$SM_STUDIO_KGW_NAME" <<'EOF'
import sys
import sagemaker
from sagemaker_ssh_helper.manager import SSMManager
import logging
logging.basicConfig(level=logging.INFO)
print(SSMManager().get_studio_kgw_instance_ids(sys.argv[1], timeout_in_sec=300)[0])
EOF
  ) || {
    echo "sm-local-ssh-ide: failed to resolve instance ID for kernel gateway '$SM_STUDIO_KGW_NAME'" >&2
    exit 1
  }

  # A successful exit with no output would otherwise start SSH with an empty ID.
  if [[ -z "$INSTANCE_ID" ]]; then
    echo "sm-local-ssh-ide: empty instance ID returned for kernel gateway '$SM_STUDIO_KGW_NAME'" >&2
    exit 1
  fi

  # replace with your JetBrains License Server host, or leave it as is if you don't use one
  JB_LICENSE_SERVER_HOST="jetbrains-license-server.example.com"

  # A previously saved host in ~/.sm-jb-license-server takes precedence over
  # the default above; otherwise the default is persisted for later runs.
  if [[ -f ~/.sm-jb-license-server ]]; then
    echo "sm-local-ssh-ide: ~/.sm-jb-license-server file with PyCharm license server host is already configured, skipping override"
    JB_LICENSE_SERVER_HOST="$(<~/.sm-jb-license-server)"
  else
    echo "sm-local-ssh-ide: Saving PyCharm License server host into ~/.sm-jb-license-server"
    echo "$JB_LICENSE_SERVER_HOST" > ~/.sm-jb-license-server
  fi

  if [[ "$OPTIONS" == "--ssh-only" ]]; then
    # Forward only local port 10022 to the remote SSH port.
    sm-local-start-ssh "$INSTANCE_ID" \
        -L localhost:10022:localhost:22
  else
    # Additionally forward local ports 5901 and 8889 and reverse-forward
    # remote port 443 to the license server host.
    sm-local-start-ssh "$INSTANCE_ID" \
        -L localhost:10022:localhost:22 \
        -L localhost:5901:localhost:5901 \
        -L localhost:8889:localhost:8889 \
        -R 127.0.0.1:443:"$JB_LICENSE_SERVER_HOST":443
  fi

elif [[ "$COMMAND" == "run-command" ]]; then

  # Drop the sub-command name; the remaining arguments are the remote command.
  shift

  # "$@" keeps each argument intact and prevents local glob expansion and
  # re-splitting before the command is handed to ssh.
  # shellcheck disable=SC2029
  ssh -i ~/.ssh/sagemaker-ssh-gw -p 10022 root@localhost \
    -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \
    "$@"

else
    # Legacy invocation without a sub-command: warn and re-run as "connect".
    echo "Deprecated warning: sm-local-ssh-ide <kernel_gateway_name> is deprecated and will be removed in future versions, use sm-local-ssh-ide connect <kernel_gateway_name> instead"
    echo "Re-trying with the 'connect' argument:"
    echo "sm-local-ssh-ide connect $*"
    "$0" connect "$@"
fi