诸葛温侯 · 9 months ago
Commit ae219bd0b4
100 changed files with 13,173 additions and 0 deletions
  1. .circleci/config.yml (+128 -0)
  2. .github/CODEOWNERS (+21 -0)
  3. .github/ISSUE_TEMPLATE.md (+3 -0)
  4. .github/ISSUE_TEMPLATE/bug_report.md (+43 -0)
  5. .github/ISSUE_TEMPLATE/documentation.md (+15 -0)
  6. .github/ISSUE_TEMPLATE/feature_request.md (+24 -0)
  7. .github/ISSUE_TEMPLATE/how-to-question.md (+33 -0)
  8. .github/PULL_REQUEST_TEMPLATE.md (+16 -0)
  9. .github/stale.yml (+30 -0)
  10. .github/workflows/build.yml (+81 -0)
  11. .github/workflows/release.yml (+161 -0)
  12. .gitignore (+141 -0)
  13. .gitmodules (+4 -0)
  14. .pre-commit-config.yaml (+40 -0)
  15. CODE_OF_CONDUCT.md (+77 -0)
  16. CONTRIBUTING.md (+82 -0)
  17. LICENSE (+21 -0)
  18. MANIFEST.in (+1 -0)
  19. README.md (+242 -0)
  20. RELEASE.md (+13 -0)
  21. docs/Makefile (+20 -0)
  22. docs/command_line_tools.rst (+85 -0)
  23. docs/conf.py (+98 -0)
  24. docs/criterions.rst (+31 -0)
  25. docs/data.rst (+58 -0)
  26. docs/docutils.conf (+2 -0)
  27. docs/fairseq.gif (BIN)
  28. docs/fairseq_logo.png (BIN)
  29. docs/getting_started.rst (+216 -0)
  30. docs/hydra_integration.md (+284 -0)
  31. docs/index.rst (+49 -0)
  32. docs/lr_scheduler.rst (+34 -0)
  33. docs/make.bat (+36 -0)
  34. docs/models.rst (+104 -0)
  35. docs/modules.rst (+9 -0)
  36. docs/optim.rst (+38 -0)
  37. docs/overview.rst (+74 -0)
  38. docs/tasks.rst (+61 -0)
  39. docs/tutorial_classifying_names.rst (+415 -0)
  40. docs/tutorial_simple_lstm.rst (+518 -0)
  41. examples/.gitignore (+2 -0)
  42. examples/MMPT/.gitignore (+139 -0)
  43. examples/MMPT/CONFIG.md (+41 -0)
  44. examples/MMPT/DATASET.md (+34 -0)
  45. examples/MMPT/README.md (+166 -0)
  46. examples/MMPT/endtask.md (+41 -0)
  47. examples/MMPT/locallaunch.py (+148 -0)
  48. examples/MMPT/mmpt/__init__.py (+12 -0)
  49. examples/MMPT/mmpt/datasets/__init__.py (+10 -0)
  50. examples/MMPT/mmpt/datasets/fairseqmmdataset.py (+57 -0)
  51. examples/MMPT/mmpt/datasets/mmdataset.py (+111 -0)
  52. examples/MMPT/mmpt/evaluators/__init__.py (+13 -0)
  53. examples/MMPT/mmpt/evaluators/evaluator.py (+54 -0)
  54. examples/MMPT/mmpt/evaluators/metric.py (+313 -0)
  55. examples/MMPT/mmpt/evaluators/predictor.py (+595 -0)
  56. examples/MMPT/mmpt/losses/__init__.py (+16 -0)
  57. examples/MMPT/mmpt/losses/fairseqmmloss.py (+63 -0)
  58. examples/MMPT/mmpt/losses/loss.py (+87 -0)
  59. examples/MMPT/mmpt/losses/nce.py (+156 -0)
  60. examples/MMPT/mmpt/models/__init__.py (+17 -0)
  61. examples/MMPT/mmpt/models/fairseqmmmodel.py (+51 -0)
  62. examples/MMPT/mmpt/models/mmfusion.py (+926 -0)
  63. examples/MMPT/mmpt/models/mmfusionnlg.py (+999 -0)
  64. examples/MMPT/mmpt/models/transformermodel.py (+734 -0)
  65. examples/MMPT/mmpt/modules/__init__.py (+10 -0)
  66. examples/MMPT/mmpt/modules/mm.py (+145 -0)
  67. examples/MMPT/mmpt/modules/retri.py (+429 -0)
  68. examples/MMPT/mmpt/modules/vectorpool.py (+246 -0)
  69. examples/MMPT/mmpt/processors/__init__.py (+23 -0)
  70. examples/MMPT/mmpt/processors/dedupprocessor.py (+242 -0)
  71. examples/MMPT/mmpt/processors/dsprocessor.py (+848 -0)
  72. examples/MMPT/mmpt/processors/how2processor.py (+887 -0)
  73. examples/MMPT/mmpt/processors/how2retriprocessor.py (+100 -0)
  74. examples/MMPT/mmpt/processors/models/s3dg.py (+336 -0)
  75. examples/MMPT/mmpt/processors/processor.py (+274 -0)
  76. examples/MMPT/mmpt/tasks/__init__.py (+22 -0)
  77. examples/MMPT/mmpt/tasks/fairseqmmtask.py (+104 -0)
  78. examples/MMPT/mmpt/tasks/milncetask.py (+27 -0)
  79. examples/MMPT/mmpt/tasks/retritask.py (+253 -0)
  80. examples/MMPT/mmpt/tasks/task.py (+184 -0)
  81. examples/MMPT/mmpt/tasks/vlmtask.py (+27 -0)
  82. examples/MMPT/mmpt/utils/__init__.py (+68 -0)
  83. examples/MMPT/mmpt/utils/load_config.py (+81 -0)
  84. examples/MMPT/mmpt/utils/shardedtensor.py (+46 -0)
  85. examples/MMPT/mmpt_cli/localjob.py (+117 -0)
  86. examples/MMPT/mmpt_cli/predict.py (+113 -0)
  87. examples/MMPT/pretraining.md (+29 -0)
  88. examples/MMPT/projects/mfmmlm.yaml (+59 -0)
  89. examples/MMPT/projects/mtm/mmfusionmtm.yaml (+19 -0)
  90. examples/MMPT/projects/mtm/vlm.yaml (+8 -0)
  91. examples/MMPT/projects/mtm/vlm/coin.yaml (+47 -0)
  92. examples/MMPT/projects/mtm/vlm/crosstask.yaml (+53 -0)
  93. examples/MMPT/projects/mtm/vlm/how2.yaml (+55 -0)
  94. examples/MMPT/projects/mtm/vlm/test_coin.yaml (+31 -0)
  95. examples/MMPT/projects/mtm/vlm/test_crosstask.yaml (+38 -0)
  96. examples/MMPT/projects/mtm/vlm/test_crosstask_zs.yaml (+38 -0)
  97. examples/MMPT/projects/mtm/vlm/test_vtt.yaml (+29 -0)
  98. examples/MMPT/projects/mtm/vlm/test_vttqa.yaml (+29 -0)
  99. examples/MMPT/projects/mtm/vlm/test_youcook.yaml (+31 -0)
  100. examples/MMPT/projects/mtm/vlm/test_youcookcap.yaml (+32 -0)

+ 128 - 0
.circleci/config.yml

@@ -0,0 +1,128 @@
+# Use 2.1 for orbs
+version: 2.1
+
+# -------------------------------------------------------------------------------------
+# Environments to run the jobs in
+# -------------------------------------------------------------------------------------
+gpu: &gpu
+  environment:
+    CUDA_VERSION: "11.2"
+  machine:
+    image: ubuntu-2004-cuda-11.2:202103-01
+  resource_class: gpu.nvidia.medium.multi
+
+
+# -------------------------------------------------------------------------------------
+# Re-usable commands
+# -------------------------------------------------------------------------------------
+cache_key: &cache_key cache-key-{{ .Environment.CIRCLE_JOB }}-{{ checksum ".circleci/config.yml" }}-{{ checksum "setup.py"}}
+
+install_dep_pt1_10: &install_dep_pt1_10
+  - run:
+      name: Install Pytorch Dependencies
+      command: |
+        source activate fairseq
+        pip install --upgrade setuptools
+        pip install torch==1.10.1+cu111 torchaudio==0.10.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
+        python -c 'import torch; print("Torch version:", torch.__version__)'
+
+install_dep_pt1_12: &install_dep_pt1_12
+  - run:
+      name: Install Pytorch Dependencies
+      command: |
+        source activate fairseq
+        pip install --upgrade setuptools
+        pip install torch==1.12.1+cu116 torchaudio==0.12.1+cu116 -f https://download.pytorch.org/whl/torch_stable.html
+        python -c 'import torch; print("Torch version:", torch.__version__)'
+
+install_repo: &install_repo
+  - run:
+      name: Install Repository
+      command: |
+        source activate fairseq
+        python -m pip install fairscale
+        python -m pip install -e '.[dev,docs]'
+        python -c 'import torch; print("Torch version:", torch.__version__)'
+
+run_unittests: &run_unittests
+  - run:
+      name: Run Unit Tests
+      command: |
+        source activate fairseq
+        pytest tests/gpu/test_binaries_gpu.py
+
+check_nvidia_driver: &check_nvidia_driver
+  - run:
+      name: Check NVIDIA Driver
+      working_directory: ~/
+      command: |
+        pyenv versions
+        nvidia-smi
+
+create_conda_env: &create_conda_env
+  - run:
+      name: Install and Create Conda Environment
+      command: |
+        curl -o ~/miniconda.sh -O  https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
+        chmod +x ~/miniconda.sh
+        bash ~/miniconda.sh -b -p $HOME/miniconda
+        rm ~/miniconda.sh
+        echo 'export PATH=$HOME/miniconda/bin:$PATH' >> $BASH_ENV
+        source $BASH_ENV
+        if [ ! -d ~/miniconda/envs/fairseq ]
+        then
+          conda create -y -n fairseq python=3.8
+        fi
+        source activate fairseq
+        python --version
+        pip install --upgrade pip
+# -------------------------------------------------------------------------------------
+# Jobs to run
+# -------------------------------------------------------------------------------------
+
+jobs:
+
+  gpu_tests_pt1_10:
+    <<: *gpu
+
+    working_directory: ~/fairseq-py
+
+    steps:
+      - checkout
+      - <<: *check_nvidia_driver
+      - <<: *create_conda_env
+      - restore_cache:
+          key: *cache_key
+      - <<: *install_dep_pt1_10
+      - save_cache:
+          paths:
+            - ~/miniconda/
+          key: *cache_key
+      - <<: *install_repo
+      - <<: *run_unittests
+
+  gpu_tests_pt1_12:
+    <<: *gpu
+
+    working_directory: ~/fairseq-py
+
+    steps:
+      - checkout
+      - <<: *check_nvidia_driver
+      - <<: *create_conda_env
+      - restore_cache:
+          key: *cache_key
+      - <<: *install_dep_pt1_12
+      - save_cache:
+          paths:
+            - ~/miniconda/
+          key: *cache_key
+      - <<: *install_repo
+      - <<: *run_unittests
+
+workflows:
+  version: 2
+  build:
+    jobs:
+      - gpu_tests_pt1_12
+      - gpu_tests_pt1_10

+ 21 - 0
.github/CODEOWNERS

@@ -0,0 +1,21 @@
+# Setting up CODEOWNERS for UST related codebase
+# Documentation for open sourced models relevant to UST
+examples/speech_to_text     @kahne @sravyapopuri388 @jmp84
+examples/speech_to_speech   @an918tw @sravyapopuri388 @jmp84
+examples/speech_synthesis   @kahne @jmp84
+examples/simultaneous_translation   @kahne @jmp84
+examples/speech_text_joint_to_text  @yuntang @jmp84
+
+# Speech related models relevant to UST
+fairseq/models/speech_to_speech @sravyapopuri388 @jmp84
+fairseq/models/speech_to_text   @kahne @sravyapopuri388 @jmp84
+fairseq/models/text_to_speech   @kahne @jmp84
+
+# CONFORMER IMPLEMENTATION
+fairseq/modules/conformer_layer.py @sravyapopuri388 @jmp84
+fairseq/modules/espnet_multihead_attention.py @sravyapopuri388 @jmp84
+fairseq/modules/rotary_positional_embedding.py @sravyapopuri388 @jmp84
+fairseq/modules/positional_encoding.py @sravyapopuri388 @jmp84
+
+# Machine Translation/NLLB
+fairseq/tasks/translation.py @gwenzek

+ 3 - 0
.github/ISSUE_TEMPLATE.md

@@ -0,0 +1,3 @@
+## 👉 [Please follow one of these issue templates](https://github.com/pytorch/fairseq/issues/new/choose) 👈
+
+Note: to keep the backlog clean and actionable, issues may be immediately closed if they do not follow one of the above issue templates.

+ 43 - 0
.github/ISSUE_TEMPLATE/bug_report.md

@@ -0,0 +1,43 @@
+---
+name: 🐛 Bug Report
+about: Submit a bug report to help us improve
+labels: 'bug, needs triage'
+---
+
+## 🐛 Bug
+
+<!-- A clear and concise description of what the bug is. -->
+
+### To Reproduce
+
+Steps to reproduce the behavior (**always include the command you ran**):
+
+1. Run cmd '....'
+2. See error
+
+<!-- If you have a code sample, error messages, stack traces, please provide it here as well -->
+
+
+#### Code sample
+<!-- Ideally attach a minimal code sample to reproduce the described issue.
+Minimal means having the shortest code but still preserving the bug. -->
+
+### Expected behavior
+
+<!-- A clear and concise description of what you expected to happen. -->
+
+### Environment
+
+ - fairseq Version (e.g., 1.0 or main):
+ - PyTorch Version (e.g., 1.0)
+ - OS (e.g., Linux):
+ - How you installed fairseq (`pip`, source):
+ - Build command you used (if compiling from source):
+ - Python version:
+ - CUDA/cuDNN version:
+ - GPU models and configuration:
+ - Any other relevant information:
+
+### Additional context
+
+<!-- Add any other context about the problem here. -->

+ 15 - 0
.github/ISSUE_TEMPLATE/documentation.md

@@ -0,0 +1,15 @@
+---
+name: 📚 Documentation/Typos
+about: Report an issue related to documentation or a typo
+labels: 'documentation, needs triage'
+---
+
+## 📚 Documentation
+
+For typos and doc fixes, please go ahead and:
+
+1. Create an issue.
+2. Fix the typo.
+3. Submit a PR.
+
+Thanks!

+ 24 - 0
.github/ISSUE_TEMPLATE/feature_request.md

@@ -0,0 +1,24 @@
+---
+name: 🚀 Feature Request
+about: Submit a proposal/request for a new feature
+labels: 'enhancement, help wanted, needs triage'
+---
+
+## 🚀 Feature Request
+<!-- A clear and concise description of the feature proposal -->
+
+### Motivation
+
+<!-- Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too -->
+
+### Pitch
+
+<!-- A clear and concise description of what you want to happen. -->
+
+### Alternatives
+
+<!-- A clear and concise description of any alternative solutions or features you've considered, if any. -->
+
+### Additional context
+
+<!-- Add any other context or screenshots about the feature request here. -->

+ 33 - 0
.github/ISSUE_TEMPLATE/how-to-question.md

@@ -0,0 +1,33 @@
+---
+name: ❓ Questions/Help
+about: If you have questions, please first search existing issues and docs
+labels: 'question, needs triage'
+---
+
+## ❓ Questions and Help
+
+### Before asking:
+1. search the issues.
+2. search the docs.
+
+<!-- If you still can't find what you need: -->
+
+#### What is your question?
+
+#### Code
+
+<!-- Please paste a code snippet if your question requires it! -->
+
+#### What have you tried?
+
+#### What's your environment?
+
+ - fairseq Version (e.g., 1.0 or main):
+ - PyTorch Version (e.g., 1.0)
+ - OS (e.g., Linux):
+ - How you installed fairseq (`pip`, source):
+ - Build command you used (if compiling from source):
+ - Python version:
+ - CUDA/cuDNN version:
+ - GPU models and configuration:
+ - Any other relevant information:

+ 16 - 0
.github/PULL_REQUEST_TEMPLATE.md

@@ -0,0 +1,16 @@
+# Before submitting
+
+- [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
+- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/main/CONTRIBUTING.md)?
+- [ ] Did you make sure to update the docs?
+- [ ] Did you write any new necessary tests?
+
+## What does this PR do?
+Fixes # (issue).
+
+## PR review
+Anyone in the community is free to review the PR once the tests have passed.
+If we didn't discuss your PR in Github issues there's a high chance it will not be merged.
+
+## Did you have fun?
+Make sure you had fun coding 🙃

+ 30 - 0
.github/stale.yml

@@ -0,0 +1,30 @@
+# Configuration for probot-stale - https://github.com/probot/stale
+# Mostly copied from github.com/facebook/react/blob/master/.github/stale.yml
+# Number of days of inactivity before an issue becomes stale
+daysUntilStale: 90
+# Number of days of inactivity before a stale issue is closed
+daysUntilClose: 7
+# Issues with these labels will never be considered stale
+exemptLabels:
+  - bug
+# Label to use when marking an issue as stale
+staleLabel: stale
+issues:
+  # Comment to post when marking an issue as stale.
+  markComment: >
+    This issue has been automatically marked as stale.
+    **If this issue is still affecting you, please leave any comment** (for example, "bump"), and we'll keep it open.
+    We are sorry that we haven't been able to prioritize it yet. If you have any new additional information, please include it with your comment!
+  # Comment to post when closing a stale issue.
+  closeComment: >
+    Closing this issue after a prolonged period of inactivity. If this issue is still present in the latest release, please create a new issue with up-to-date information. Thank you!
+pulls:
+  # Comment to post when marking a pull request as stale.
+  markComment: >
+    This pull request has been automatically marked as stale.
+    **If this pull request is still relevant, please leave any comment** (for example, "bump"), and we'll keep it open.
+    We are sorry that we haven't been able to prioritize reviewing it yet. Your contribution is very much appreciated.
+  # Comment to post when closing a stale pull request.
+  closeComment: >
+    Closing this pull request after a prolonged period of inactivity. If this issue is still present in the latest release, please ask for this pull request to be reopened. Thank you!
+

+ 81 - 0
.github/workflows/build.yml

@@ -0,0 +1,81 @@
+name: build
+
+on:
+  # Trigger the workflow on push to main or any pull request
+  push:
+    branches:
+      - main
+  pull_request:
+
+jobs:
+  build:
+
+    strategy:
+      max-parallel: 4
+      matrix:
+        platform: [ubuntu-latest, macos-latest]
+        python-version: [3.8, 3.9]
+
+    runs-on: ${{ matrix.platform }}
+
+    steps:
+    - uses: actions/checkout@v2
+
+    - name: Set up Python ${{ matrix.python-version }}
+      uses: actions/setup-python@v2
+      with:
+        python-version: ${{ matrix.python-version }}
+
+    - name: Conditionally install pytorch
+      if: matrix.platform == 'windows-latest'
+      run: pip3 install torch -f https://download.pytorch.org/whl/torch_stable.html
+
+    - name: Install locally
+      run: |
+        python -m pip install --upgrade pip
+        git submodule update --init --recursive
+        python -m pip install .
+
+    - name: Check installation
+      working-directory: /tmp
+      run: python $GITHUB_WORKSPACE/scripts/check_installation.py
+
+    - name: Install optional test requirements
+      run: |
+        python -m pip install '.[dev,docs]'
+        python -m pip install iopath transformers pyarrow
+        python -m pip install git+https://github.com/facebookresearch/fairscale.git@main
+        python -m pip install pygit2 pgzip
+        
+    - name: Install xformers for Macos
+      if: matrix.platform == 'macos-latest'
+      run: |
+        brew install llvm libomp
+        CC=/usr/local/opt/llvm/bin/clang CXX=clang++ pip install git+https://github.com/facebookresearch/xformers.git@main
+
+    - name: Install xformers for non-MacOS
+      if: matrix.platform != 'macos-latest'
+      run: |
+        python -m pip install --progress-bar off git+https://github.com/facebookresearch/xformers.git@main
+
+    - name: Lint with black
+      run: black --check --diff .
+
+    - name: Lint with flake8
+      run: |
+        # stop the build if there are Python syntax errors or undefined names
+        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
+        # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
+        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
+
+    - name: Build doc
+      run: make singlehtml
+      working-directory: docs/
+
+    - name: Run tests
+      # When installing in non-editable mode, the .so files will be generated in 'site-packages/fairseq'.
+      # But by default, pytest import machinery will load local fairseq, and won't see the .so.
+      # Use --import-mode=append to favorize the 'site-packages/fairseq'.
+      # https://docs.pytest.org/en/7.1.x/explanation/pythonpath.html
+      run: pytest --import-mode=append -vvv tests/
+

+ 161 - 0
.github/workflows/release.yml

@@ -0,0 +1,161 @@
+name: Fairseq Release
+
+on:
+  workflow_dispatch:
+    inputs:
+      name:
+        description: 'Release Type'
+        default: 'patch'
+        required: true
+
+jobs:
+
+  get_next_version:
+    runs-on: ubuntu-latest
+    steps:
+      - name: checkout-repo-content
+        uses: actions/checkout@v2
+
+      - name: setup-python
+        uses: actions/setup-python@v2
+        with:
+          python-version: 3.8
+
+      - name: get next version and tag
+        id: get-next-version-and-tag
+        run: |
+          output=$(python3 release_utils.py --release-type ${{ github.event.inputs.name }}) 
+          echo $output
+          new_version=$(echo $output | awk '{print $1}')
+          new_tag=$(echo $output | awk '{print $2}')
+          echo "new version is $new_version"
+          echo "new tag is $new_tag"
+          echo ::set-output name=version::$new_version
+          echo ::set-output name=tag::$new_tag
+          echo ::set-output name=branch_name::$new_version-release
+          echo "NEW_TAG=$new_tag" >> $GITHUB_ENV
+          echo "NEW_BRANCH=$new_version-release" >> $GITHUB_ENV
+
+
+      # update the version number in version.txt
+      - name: update version
+        id: update-version
+        run : |
+          echo "current folder = $PWD"
+          echo "current branch = $(git branch --show-current)"
+          output=$(python3 release_utils.py --release-type ${{ github.event.inputs.name }} --update-version)
+
+      - name: add and commit
+        uses: EndBug/add-and-commit@v9
+        with:
+          author_name: ${{ secrets.AUTHOR_NAME }}
+          author_email: ${{ secrets.AUTHOR_EMAIL }}
+
+          # TODO: change this to main once shipit is disabled.
+          new_branch: '${{ env.NEW_BRANCH }}'
+          default_author: github_actor
+          message: '${{ env.NEW_TAG }} release'
+          pathspec_error_handling: exitAtEnd
+
+          # Arguments for the git pull command. Use NO-PULL to avoid the action pulling at all.
+          # pull: 'NO-PULL'
+          tag: '${{ env.NEW_TAG }}'
+
+    outputs:
+      new_version: ${{ steps.get-next-version-and-tag.outputs.version }}
+      new_tag: ${{ steps.get-next-version-and-tag.outputs.tag }}
+      branch_name: ${{ steps.get-next-version-and-tag.outputs.branch_name }}
+
+  create_sdist:
+    runs-on: ubuntu-latest
+    name: Create Source Distribution
+    needs: get_next_version
+    steps:
+      - uses: actions/checkout@v3
+        with:
+          ref: ${{ needs.get_next_version.outputs.branch_name }}
+
+      - name: Install Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: '3.8'
+
+      - name: Upgrade pip
+        run: |
+          python3 -m pip install --upgrade pip
+
+      - name: Create Source Distribution
+        run: |
+          python3 -m pip install setuptools wheel twine torch
+          python3 setup.py sdist
+ 
+      - uses: actions/upload-artifact@v2
+        with:
+          path: dist/*.tar.gz
+
+  build_wheels:
+    name: Build wheels on ${{ matrix.os }}
+    runs-on: ${{ matrix.os }}
+    needs: get_next_version
+    strategy:
+      matrix:
+        os: [ubuntu-latest, macos-latest]
+
+    steps:
+      - uses: actions/checkout@v3
+        with:
+          ref: ${{ needs.get_next_version.outputs.branch_name }}
+
+      - name: Install Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: '3.8'
+
+      - name: Upgrade pip
+        run: |
+          python3 -m pip install --upgrade pip
+
+      - name: Install cibuildwheel
+        run: |
+          python3 -m pip install cibuildwheel
+
+      - name: Build wheels for CPython
+        run: |
+          python3 -m cibuildwheel --output-dir dist
+        env:
+          CIBW_BUILD: "cp38-*64"
+          CIBW_MANYLINUX_X86_64_IMAGE: manylinux1
+          CIBW_BEFORE_BUILD: git submodule update --init --recursive && pip install .
+          # Install system library
+          CIBW_BEFORE_BUILD_LINUX: (yum install -y libffi-devel || apt-get install -y libffi-devel || apk add --update --no-cache libffi-devel || true) && (yum install -y libc6 || apt-get install -y libc6 || apk add --update --no-cache libc6 || true)
+          CIBW_ENVIRONMENT: "PIP_ONLY_BINARY=numpy"
+          CIBW_SKIP: "*musllinux*"
+
+      - uses: actions/upload-artifact@v2
+        with:
+          path: dist
+
+  upload:
+    name: Upload to PyPi and create release
+    runs-on: ubuntu-latest
+    needs: [build_wheels, create_sdist, get_next_version]
+    steps:
+      - uses: actions/download-artifact@v2
+        with:
+          name: artifact
+          path: dist
+
+      # build the PyPI package and upload it
+      - name: upload
+        env:
+          TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
+          TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
+        run: |
+          pip install setuptools wheel twine
+          python3 -m twine upload --repository pypi dist/*
+
+      # create the release on github
+      - name: create release on github
+        uses: ncipollo/release-action@v1
+        with:
+          tag: '${{ needs.get_next_version.outputs.new_tag }}'

+ 141 - 0
.gitignore

@@ -0,0 +1,141 @@
+# JetBrains PyCharm IDE
+.idea/
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# macOS dir files
+.DS_Store
+
+# Distribution / packaging
+.Python
+env/
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# Checkpoints
+checkpoints
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# dotenv
+.env
+
+# virtualenv
+.venv
+venv/
+ENV/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+
+# Generated files
+/fairseq/temporal_convolution_tbc
+/fairseq/modules/*_layer/*_forward.cu
+/fairseq/modules/*_layer/*_backward.cu
+/fairseq/version.py
+
+# data
+data-bin/
+
+# reranking
+/examples/reranking/rerank_data
+
+# Cython-generated C++ source files
+/fairseq/data/data_utils_fast.cpp
+/fairseq/data/token_block_utils_fast.cpp
+
+# VSCODE
+.vscode/ftp-sync.json
+.vscode/settings.json
+
+# Experimental Folder
+experimental/*
+
+# Weights and Biases logs
+wandb/
+
+# Hydra artifacts
+nohup.out
+multirun
+outputs

+ 4 - 0
.gitmodules

@@ -0,0 +1,4 @@
+[submodule "fairseq/model_parallel/megatron"]
+    path = fairseq/model_parallel/megatron
+    url = https://github.com/ngoyal2707/Megatron-LM
+    branch = fairseq

+ 40 - 0
.pre-commit-config.yaml

@@ -0,0 +1,40 @@
+exclude: 'build|stubs'
+
+default_language_version:
+    python: python3
+
+repos:
+-   repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.1.0
+    hooks:
+    -   id: trailing-whitespace
+    -   id: check-ast
+    -   id: check-merge-conflict
+    -   id: no-commit-to-branch
+        args: ['--branch=master']
+    -   id: check-added-large-files
+        args: ['--maxkb=500']
+    -   id: end-of-file-fixer
+
+-   repo: https://github.com/ambv/black
+    rev: 22.3.0
+    hooks:
+    - id: black
+      language_version: python3.8
+
+-   repo: https://gitlab.com/pycqa/flake8
+    rev: 3.9.2
+    hooks:
+    -   id: flake8
+        args: [
+            # only error for syntax errors and undefined names
+            "--select=E9,F63,F7,F82",
+        ]
+
+-   repo: https://github.com/pycqa/isort
+    rev: 5.10.1
+    hooks:
+    -   id: isort
+        exclude: README.md
+        additional_dependencies: [toml]
+        args: ["--profile", "black"]

+ 77 - 0
CODE_OF_CONDUCT.md

@@ -0,0 +1,77 @@
+# Code of Conduct
+
+## Our Pledge
+
+In the interest of fostering an open and welcoming environment, we as
+contributors and maintainers pledge to make participation in our project and
+our community a harassment-free experience for everyone, regardless of age, body
+size, disability, ethnicity, sex characteristics, gender identity and expression,
+level of experience, education, socio-economic status, nationality, personal
+appearance, race, religion, or sexual identity and orientation.
+
+## Our Standards
+
+Examples of behavior that contributes to creating a positive environment
+include:
+
+* Using welcoming and inclusive language
+* Being respectful of differing viewpoints and experiences
+* Gracefully accepting constructive criticism
+* Focusing on what is best for the community
+* Showing empathy towards other community members
+
+Examples of unacceptable behavior by participants include:
+
+* The use of sexualized language or imagery and unwelcome sexual attention or
+  advances
+* Trolling, insulting/derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or electronic
+  address, without explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+  professional setting
+
+## Our Responsibilities
+
+Project maintainers are responsible for clarifying the standards of acceptable
+behavior and are expected to take appropriate and fair corrective action in
+response to any instances of unacceptable behavior.
+
+Project maintainers have the right and responsibility to remove, edit, or
+reject comments, commits, code, wiki edits, issues, and other contributions
+that are not aligned to this Code of Conduct, or to ban temporarily or
+permanently any contributor for other behaviors that they deem inappropriate,
+threatening, offensive, or harmful.
+
+## Scope
+
+This Code of Conduct applies within all project spaces, and it also applies when
+an individual is representing the project or its community in public spaces.
+Examples of representing a project or community include using an official
+project e-mail address, posting via an official social media account, or acting
+as an appointed representative at an online or offline event. Representation of
+a project may be further defined and clarified by project maintainers.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported by contacting the project team at <conduct@pytorch.org>. All
+complaints will be reviewed and investigated and will result in a response that
+is deemed necessary and appropriate to the circumstances. The project team is
+obligated to maintain confidentiality with regard to the reporter of an incident.
+Further details of specific enforcement policies may be posted separately.
+
+Project maintainers who do not follow or enforce the Code of Conduct in good
+faith may face temporary or permanent repercussions as determined by other
+members of the project's leadership.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
+available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
+
+[homepage]: https://www.contributor-covenant.org
+
+For answers to common questions about this code of conduct, see
+https://www.contributor-covenant.org/faq
+

+ 82 - 0
CONTRIBUTING.md

@@ -0,0 +1,82 @@
+# Contributing to Facebook AI Research Sequence-to-Sequence Toolkit (fairseq)
+We want to make contributing to this project as easy and transparent as
+possible.
+
+## Pull Requests
+We actively welcome your pull requests.
+
+1. Fork the repo and create your branch from `main`.
+2. If you've added code that should be tested, add tests.
+3. If you've changed APIs, update the documentation.
+4. Ensure the test suite passes.
+5. Make sure your code lints.
+6. If you haven't already, complete the Contributor License Agreement ("CLA").
+
+## Contributor License Agreement ("CLA")
+In order to accept your pull request, we need you to submit a CLA. You only need
+to do this once to work on any of Facebook's open source projects.
+
+Complete your CLA here: <https://code.facebook.com/cla>
+
+## Issues
+We use GitHub issues to track public bugs. Please ensure your description is
+clear and has sufficient instructions to be able to reproduce the issue.
+
+## License
+By contributing to Facebook AI Research Sequence-to-Sequence Toolkit (fairseq),
+you agree that your contributions will be licensed under the LICENSE file in
+the root directory of this source tree.
+
+## Pre-commit hooks
+In order to ensure your code lints, there are pre-commit hooks configured in the repository which you can install.
+After installation, they will automatically run each time you commit.
+An abbreviated guide is given below; for more information, refer to [the official pre-commit documentation](https://pre-commit.com/).
+
+### Installation
+```
+pip install pre-commit
+pre-commit install
+```
+
+### Usage
+Just commit your changes:
+```
+git commit -m "My informative commit message"
+```
+
+If there was a failure, you will get feedback
+```
+[INFO] Initializing environment for https://github.com/PyCQA/flake8.
+[INFO] Installing environment for https://github.com/pre-commit/pre-commit-hooks.
+[INFO] Once installed this environment will be reused.
+[INFO] This may take a few minutes...
+[INFO] Installing environment for https://github.com/PyCQA/flake8.
+[INFO] Once installed this environment will be reused.
+[INFO] This may take a few minutes...
+Trim Trailing Whitespace.................................................Failed
+- hook id: trailing-whitespace
+- exit code: 1
+- files were modified by this hook
+Fixing examples/nllb/modeling/wmt15_benchmark/eval_langs2.sh
+Fix End of Files.........................................................Failed
+- hook id: end-of-file-fixer
+- exit code: 1
+- files were modified by this hook
+Fixing examples/few_shot/scripts/schedule_jobs_few_shot.py
+flake8...................................................................Passed
+```
+
+Certain hooks modify your files to comply.
+To include these modifications, you will need to add them (i.e. `git add ...`) and commit again.
+
+If all is well, you should see something like:
+```
+Trim Trailing Whitespace.................................................Passed
+Fix End of Files.........................................................Passed
+flake8...................................................................Passed
+[gshard-fix-ci 8698644e1] Fix lint, add pre-commit hooks
+ 10 files changed, 148 insertions(+), 110 deletions(-)
+ create mode 100644 .flake8
+ create mode 100644 .pre-commit-config.yaml
+ rename examples/nllb/modeling/wmt15_benchmark/{eval_langs2.py => eval_langs2.sh} (99%)
+```

+ 21 - 0
LICENSE

@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) Facebook, Inc. and its affiliates.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.

+ 1 - 0
MANIFEST.in

@@ -0,0 +1 @@
+include fairseq/version.txt

+ 242 - 0
README.md

@@ -0,0 +1,242 @@
+<p align="center">
+  <img src="docs/fairseq_logo.png" width="150">
+  <br />
+  <br />
+  <a href="https://opensource.fb.com/support-ukraine"><img alt="Support Ukraine" src="https://img.shields.io/badge/Support-Ukraine-FFD500?style=flat&labelColor=005BBB" /></a>
+  <a href="https://github.com/pytorch/fairseq/blob/main/LICENSE"><img alt="MIT License" src="https://img.shields.io/badge/license-MIT-blue.svg" /></a>
+  <a href="https://github.com/pytorch/fairseq/releases"><img alt="Latest Release" src="https://img.shields.io/github/release/pytorch/fairseq.svg" /></a>
+  <a href="https://github.com/pytorch/fairseq/actions?query=workflow:build"><img alt="Build Status" src="https://github.com/pytorch/fairseq/workflows/build/badge.svg" /></a>
+  <a href="https://fairseq.readthedocs.io/en/latest/?badge=latest"><img alt="Documentation Status" src="https://readthedocs.org/projects/fairseq/badge/?version=latest" /></a>
+  <a href="https://app.circleci.com/pipelines/github/facebookresearch/fairseq/"><img alt="CicleCI Status" src="https://circleci.com/gh/facebookresearch/fairseq.svg?style=shield" /></a>
+</p>
+
+--------------------------------------------------------------------------------
+
+Fairseq(-py) is a sequence modeling toolkit that allows researchers and
+developers to train custom models for translation, summarization, language
+modeling and other text generation tasks.
+
+We provide reference implementations of various sequence modeling papers:
+
+<details><summary>List of implemented papers</summary><p>
+
+* **Convolutional Neural Networks (CNN)**
+  + [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/conv_lm/README.md)
+  + [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md)
+  + [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
+  + [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md)
+  + [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
+* **LightConv and DynamicConv models**
+  + [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md)
+* **Long Short-Term Memory (LSTM) networks**
+  + Effective Approaches to Attention-based Neural Machine Translation (Luong et al., 2015)
+* **Transformer (self-attention) networks**
+  + Attention Is All You Need (Vaswani et al., 2017)
+  + [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md)
+  + [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md)
+  + [Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)](examples/language_model/README.adaptive_inputs.md)
+  + [Lexically constrained decoding with dynamic beam allocation (Post & Vilar, 2018)](examples/constrained_decoding/README.md)
+  + [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context (Dai et al., 2019)](examples/truncated_bptt/README.md)
+  + [Adaptive Attention Span in Transformers (Sukhbaatar et al., 2019)](examples/adaptive_span/README.md)
+  + [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
+  + [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
+  + [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
+  + [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md )
+  + [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et al., 2020)](examples/mbart/README.md)
+  + [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md)
+  + [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md)
+  + [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md)
+  + [Generating Medical Reports from Patient-Doctor Conversations Using Sequence-to-Sequence Models (Enarvi et al., 2020)](examples/pointer_generator/README.md)
+  + [Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)](examples/linformer/README.md)
+  + [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md)
+  + [Deep Transformers with Latent Depth (Li et al., 2020)](examples/latent_depth/README.md)
+  + [Unsupervised Cross-lingual Representation Learning for Speech Recognition (Conneau et al., 2020)](https://arxiv.org/abs/2006.13979)
+  + [Self-training and Pre-training are Complementary for Speech Recognition (Xu et al., 2020)](https://arxiv.org/abs/2010.11430)
+  + [Robust wav2vec 2.0: Analyzing Domain Shift in Self-Supervised Pre-Training (Hsu, et al., 2021)](https://arxiv.org/abs/2104.01027)
+  + [Unsupervised Speech Recognition (Baevski, et al., 2021)](https://arxiv.org/abs/2105.11084)
+  + [Simple and Effective Zero-shot Cross-lingual Phoneme Recognition (Xu et al., 2021)](https://arxiv.org/abs/2109.11680)
+  + [VideoCLIP: Contrastive Pre-training for Zero-shot Video-Text Understanding (Xu et al., 2021)](https://arxiv.org/pdf/2109.14084.pdf)
+  + [VLM: Task-agnostic Video-Language Model Pre-training for Video Understanding (Xu et al., 2021)](https://aclanthology.org/2021.findings-acl.370.pdf)
+  + [NormFormer: Improved Transformer Pretraining with Extra Normalization (Shleifer et al., 2021)](examples/normformer/README.md)
+* **Non-autoregressive Transformers**
+  + Non-Autoregressive Neural Machine Translation (Gu et al., 2017)
+  + Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018)
+  + Insertion Transformer: Flexible Sequence Generation via Insertion Operations (Stern et al. 2019)
+  + Mask-Predict: Parallel Decoding of Conditional Masked Language Models (Ghazvininejad et al., 2019)
+  + [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
+* **Finetuning**
+  + [Better Fine-Tuning by Reducing Representational Collapse (Aghajanyan et al. 2020)](examples/rxf/README.md)
+
+</p></details>
+
+### What's New:
+* May 2023 [Released models for Scaling Speech Technology to 1,000+ Languages  (Pratap, et al., 2023)](examples/mms/README.md)
+* June 2022 [Released code for wav2vec-U 2.0 from Towards End-to-end Unsupervised Speech Recognition (Liu, et al., 2022)](examples/wav2vec/unsupervised/README.md)
+* May 2022 [Integration with xFormers](https://github.com/facebookresearch/xformers)
+* December 2021 [Released Direct speech-to-speech translation code](examples/speech_to_speech/README.md)
+* October 2021 [Released VideoCLIP and VLM models](examples/MMPT/README.md)
+* October 2021 [Released multilingual finetuned XLSR-53 model](examples/wav2vec/README.md)
+* September 2021 [`master` branch renamed to `main`](https://github.com/github/renaming).
+* July 2021 [Released DrNMT code](examples/discriminative_reranking_nmt/README.md)
+* July 2021 [Released Robust wav2vec 2.0 model](examples/wav2vec/README.md)
+* June 2021 [Released XLMR-XL and XLMR-XXL models](examples/xlmr/README.md)
+* May 2021 [Released Unsupervised Speech Recognition code](examples/wav2vec/unsupervised/README.md)
+* March 2021 [Added full parameter and optimizer state sharding + CPU offloading](examples/fully_sharded_data_parallel/README.md)
+* February 2021 [Added LASER training code](examples/laser/README.md)
+* December 2020: [Added Adaptive Attention Span code](examples/adaptive_span/README.md)
+* December 2020: [GottBERT model and code released](examples/gottbert/README.md)
+* November 2020: Adopted the [Hydra](https://github.com/facebookresearch/hydra) configuration framework
+  * [see documentation explaining how to use it for new and existing projects](docs/hydra_integration.md)
+* November 2020: [fairseq 0.10.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.10.0)
+* October 2020: [Added R3F/R4F (Better Fine-Tuning) code](examples/rxf/README.md)
+* October 2020: [Deep Transformer with Latent Depth code released](examples/latent_depth/README.md)
+* October 2020: [Added CRISS models and code](examples/criss/README.md)
+
+<details><summary>Previous updates</summary><p>
+
+* September 2020: [Added Linformer code](examples/linformer/README.md)
+* September 2020: [Added pointer-generator networks](examples/pointer_generator/README.md)
+* August 2020: [Added lexically constrained decoding](examples/constrained_decoding/README.md)
+* August 2020: [wav2vec2 models and code released](examples/wav2vec/README.md)
+* July 2020: [Unsupervised Quality Estimation code released](examples/unsupervised_quality_estimation/README.md)
+* May 2020: [Follow fairseq on Twitter](https://twitter.com/fairseq)
+* April 2020: [Monotonic Multihead Attention code released](examples/simultaneous_translation/README.md)
+* April 2020: [Quant-Noise code released](examples/quant_noise/README.md)
+* April 2020: [Initial model parallel support and 11B parameters unidirectional LM released](examples/megatron_11b/README.md)
+* March 2020: [Byte-level BPE code released](examples/byte_level_bpe/README.md)
+* February 2020: [mBART model and code released](examples/mbart/README.md)
+* February 2020: [Added tutorial for back-translation](https://github.com/pytorch/fairseq/tree/main/examples/backtranslation#training-your-own-model-wmt18-english-german)
+* December 2019: [fairseq 0.9.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.9.0)
+* November 2019: [VizSeq released (a visual analysis toolkit for evaluating fairseq models)](https://facebookresearch.github.io/vizseq/docs/getting_started/fairseq_example)
+* November 2019: [CamemBERT model and code released](examples/camembert/README.md)
+* November 2019: [BART model and code released](examples/bart/README.md)
+* November 2019: [XLM-R models and code released](examples/xlmr/README.md)
+* September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md)
+* August 2019: [WMT'19 models released](examples/wmt19/README.md)
+* July 2019: fairseq relicensed under MIT license
+* July 2019: [RoBERTa models and code released](examples/roberta/README.md)
+* June 2019: [wav2vec models and code released](examples/wav2vec/README.md)
+
+</p></details>
+
+### Features:
+
+* multi-GPU training on one machine or across multiple machines (data and model parallel)
+* fast generation on both CPU and GPU with multiple search algorithms implemented:
+  + beam search
+  + Diverse Beam Search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424))
+  + sampling (unconstrained, top-k and top-p/nucleus)
+  + [lexically constrained decoding](examples/constrained_decoding/README.md) (Post & Vilar, 2018)
+* [gradient accumulation](https://fairseq.readthedocs.io/en/latest/getting_started.html#large-mini-batch-training-with-delayed-updates) enables training with large mini-batches even on a single GPU
+* [mixed precision training](https://fairseq.readthedocs.io/en/latest/getting_started.html#training-with-half-precision-floating-point-fp16) (trains faster with less GPU memory on [NVIDIA tensor cores](https://developer.nvidia.com/tensor-cores))
+* [extensible](https://fairseq.readthedocs.io/en/latest/overview.html): easily register new models, criterions, tasks, optimizers and learning rate schedulers
+* [flexible configuration](docs/hydra_integration.md) based on [Hydra](https://github.com/facebookresearch/hydra) allowing a combination of code, command-line and file based configuration
+* [full parameter and optimizer state sharding](examples/fully_sharded_data_parallel/README.md)
+* [offloading parameters to CPU](examples/fully_sharded_data_parallel/README.md)
+
+We also provide [pre-trained models for translation and language modeling](#pre-trained-models-and-examples)
+with a convenient `torch.hub` interface:
+
+``` python
+en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model')
+en2de.translate('Hello world', beam=5)
+# 'Hallo Welt'
+```
+
+See the PyTorch Hub tutorials for [translation](https://pytorch.org/hub/pytorch_fairseq_translation/)
+and [RoBERTa](https://pytorch.org/hub/pytorch_fairseq_roberta/) for more examples.
+
+# Requirements and Installation
+
+* [PyTorch](http://pytorch.org/) version >= 1.10.0
+* Python version >= 3.8
+* For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
+* **To install fairseq** and develop locally:
+
+``` bash
+git clone https://github.com/pytorch/fairseq
+cd fairseq
+pip install --editable ./
+
+# on MacOS:
+# CFLAGS="-stdlib=libc++" pip install --editable ./
+
+# to install the latest stable release (0.10.x)
+# pip install fairseq
+```
+
+* **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library:
+
+``` bash
+git clone https://github.com/NVIDIA/apex
+cd apex
+pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" \
+  --global-option="--deprecated_fused_adam" --global-option="--xentropy" \
+  --global-option="--fast_multihead_attn" ./
+```
+
+* **For large datasets** install [PyArrow](https://arrow.apache.org/docs/python/install.html#using-pip): `pip install pyarrow`
+* If you use Docker, make sure to increase the shared memory size either with `--ipc=host` or `--shm-size`
+  as command line options to `nvidia-docker run`.
+
+# Getting Started
+
+The [full documentation](https://fairseq.readthedocs.io/) contains instructions
+for getting started, training new models and extending fairseq with new model
+types and tasks.
+
+# Pre-trained models and examples
+
+We provide pre-trained models and pre-processed, binarized test sets for several tasks listed below,
+as well as example training and evaluation commands.
+
+* [Translation](examples/translation/README.md): convolutional and transformer models are available
+* [Language Modeling](examples/language_model/README.md): convolutional and transformer models are available
+
+We also have more detailed READMEs to reproduce results from specific papers:
+
+* [XLS-R: Self-supervised Cross-lingual Speech Representation Learning at Scale (Babu et al., 2021)](examples/wav2vec/xlsr/README.md)
+* [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md)
+* [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md)
+* [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md)
+* [Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)](examples/quant_noise/README.md)
+* [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md)
+* [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et al., 2020)](examples/mbart/README.md)
+* [Reducing Transformer Depth on Demand with Structured Dropout (Fan et al., 2019)](examples/layerdrop/README.md)
+* [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md)
+* [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
+* [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
+* [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
+* [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
+* [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
+* [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md)
+* [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md)
+* [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
+* [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md)
+* [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md)
+* [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md)
+* [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/README.conv.md)
+
+# Join the fairseq community
+
+* Twitter: https://twitter.com/fairseq
+* Facebook page: https://www.facebook.com/groups/fairseq.users
+* Google group: https://groups.google.com/forum/#!forum/fairseq-users
+
+# License
+
+fairseq(-py) is MIT-licensed.
+The license applies to the pre-trained models as well.
+
+# Citation
+
+Please cite as:
+
+``` bibtex
+@inproceedings{ott2019fairseq,
+  title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling},
+  author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli},
+  booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations},
+  year = {2019},
+}
+```

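The README's `torch.hub` snippet above demonstrates the translation entry point; the RoBERTa models referenced in the same paragraph load through the same interface. A minimal sketch, assuming the `roberta.large` hub entry point and the `encode`/`extract_features` helpers described in the linked RoBERTa tutorial:

```python
import torch

# Load a pre-trained RoBERTa model via the same torch.hub interface
# used for the translation example above (entry-point name assumed).
roberta = torch.hub.load('pytorch/fairseq', 'roberta.large')
roberta.eval()  # disable dropout for deterministic outputs

tokens = roberta.encode('Hello world!')      # BPE-encode into a tensor of token ids
features = roberta.extract_features(tokens)  # last-layer hidden states
print(features.shape)                        # e.g. torch.Size([1, 5, 1024])
```
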
+ 13 - 0
RELEASE.md

@@ -0,0 +1,13 @@
+# Creating a New Release
+
+In order to create a new release:
+
+1. Navigate to the [Fairseq Workflows](https://github.com/facebookresearch/fairseq/actions) and find the one named _Fairseq Release_. 
+
+2. Under _Run Workflow_ choose the branch `main` and for _Release Type_ enter either `major`, `minor`, or `patch`.  
+
+3. A branch named `$new_version-release` will be created where the `version.txt` file is updated. Merge those changes into `main`.
+
+4. Make sure that a [new PyPI package](https://pypi.org/project/fairseq/) has been uploaded.
+
+5. Make sure that a [new github release](https://github.com/facebookresearch/fairseq/releases) has been created.

+ 20 - 0
docs/Makefile

@@ -0,0 +1,20 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line.
+SPHINXOPTS    =
+SPHINXBUILD   = python -msphinx
+SPHINXPROJ    = fairseq
+SOURCEDIR     = .
+BUILDDIR      = _build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

+ 85 - 0
docs/command_line_tools.rst

@@ -0,0 +1,85 @@
+.. _Command-line Tools:
+
+Command-line Tools
+==================
+
+Fairseq provides several command-line tools for training and evaluating models:
+
+- :ref:`fairseq-preprocess`: Data pre-processing: build vocabularies and binarize training data
+- :ref:`fairseq-train`: Train a new model on one or multiple GPUs
+- :ref:`fairseq-generate`: Translate pre-processed data with a trained model
+- :ref:`fairseq-interactive`: Translate raw text with a trained model
+- :ref:`fairseq-score`: BLEU scoring of generated translations against reference translations
+- :ref:`fairseq-eval-lm`: Language model evaluation
+
+
+.. _fairseq-preprocess:
+
+fairseq-preprocess
+~~~~~~~~~~~~~~~~~~
+.. automodule:: fairseq_cli.preprocess
+
+    .. argparse::
+        :module: fairseq.options
+        :func: get_preprocessing_parser
+        :prog: fairseq-preprocess
+
+
+.. _fairseq-train:
+
+fairseq-train
+~~~~~~~~~~~~~
+.. automodule:: fairseq_cli.train
+
+    .. argparse::
+        :module: fairseq.options
+        :func: get_training_parser
+        :prog: fairseq-train
+
+
+.. _fairseq-generate:
+
+fairseq-generate
+~~~~~~~~~~~~~~~~
+.. automodule:: fairseq_cli.generate
+
+    .. argparse::
+        :module: fairseq.options
+        :func: get_generation_parser
+        :prog: fairseq-generate
+
+
+.. _fairseq-interactive:
+
+fairseq-interactive
+~~~~~~~~~~~~~~~~~~~
+.. automodule:: fairseq_cli.interactive
+
+    .. argparse::
+        :module: fairseq.options
+        :func: get_interactive_generation_parser
+        :prog: fairseq-interactive
+
+
+.. _fairseq-score:
+
+fairseq-score
+~~~~~~~~~~~~~
+.. automodule:: fairseq_cli.score
+
+    .. argparse::
+        :module: fairseq_cli.score
+        :func: get_parser
+        :prog: fairseq-score
+
+
+.. _fairseq-eval-lm:
+
+fairseq-eval-lm
+~~~~~~~~~~~~~~~
+.. automodule:: fairseq_cli.eval_lm
+
+    .. argparse::
+        :module: fairseq.options
+        :func: get_eval_lm_parser
+        :prog: fairseq-eval-lm

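Each console script documented above is a thin wrapper around an argparse parser exposed in `fairseq.options` (the `:func:` targets in the directives). A minimal sketch of building one of those parsers directly from Python; the option values are illustrative placeholders, not part of this commit:

```python
# Build the parser behind fairseq-preprocess and parse a sample command line.
from fairseq import options

parser = options.get_preprocessing_parser()
args = parser.parse_args([
    "--source-lang", "de",        # placeholder language pair
    "--target-lang", "en",
    "--trainpref", "data/train",  # placeholder path prefix
    "--destdir", "data-bin/demo",
])
print(args.source_lang, args.target_lang, args.destdir)
```

The same pattern applies to `get_training_parser`, `get_generation_parser`, and the other parsers referenced above.
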
+ 98 - 0
docs/conf.py

@@ -0,0 +1,98 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+#
+# fairseq documentation build configuration file, created by
+# sphinx-quickstart on Fri Aug 17 21:45:30 2018.
+#
+# This file is execfile()d with the current directory set to its
+# containing dir.
+#
+# Note that not all possible configuration values are present in this
+# autogenerated file.
+#
+# All configuration values have a default; values that are commented out
+# serve to show the default.
+
+# If extensions (or modules to document with autodoc) are in another directory,
+# add these directories to sys.path here. If the directory is relative to the
+# documentation root, use os.path.abspath to make it absolute, like shown here.
+
+import os
+import sys
+from fairseq import __version__
+
+
+# source code directory, relative to this file, for sphinx-autobuild
+sys.path.insert(0, os.path.abspath(".."))
+
+source_suffix = [".rst"]
+
+# -- General configuration ------------------------------------------------
+
+# If your documentation needs a minimal Sphinx version, state it here.
+#
+# needs_sphinx = '1.0'
+
+# Add any Sphinx extension module names here, as strings. They can be
+# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
+# ones.
+extensions = [
+    "sphinx.ext.autodoc",
+    "sphinx.ext.intersphinx",
+    "sphinx.ext.viewcode",
+    "sphinx.ext.napoleon",
+    "sphinxarg.ext",
+]
+
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ["_templates"]
+
+# The master toctree document.
+master_doc = "index"
+
+# General information about the project.
+project = "fairseq"
+copyright = "Facebook AI Research (FAIR)"
+author = "Facebook AI Research (FAIR)"
+
+github_doc_root = "https://github.com/pytorch/fairseq/tree/main/docs/"
+
+# The version info for the project you're documenting, acts as replacement for
+# |version| and |release|, also used in various other places throughout the
+# built documents.
+#
+# The short X.Y version.
+version = __version__
+# The full version, including alpha/beta/rc tags.
+release = __version__
+
+# The language for content autogenerated by Sphinx. Refer to documentation
+# for a list of supported languages.
+#
+# This is also used if you do content translation via gettext catalogs.
+# Usually you set "language" from the command line for these cases.
+language = None
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+# This patterns also effect to html_static_path and html_extra_path
+exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
+
+# The name of the Pygments (syntax highlighting) style to use.
+pygments_style = "sphinx"
+highlight_language = "python"
+
+# If true, `todo` and `todoList` produce output, else they produce nothing.
+todo_include_todos = False
+
+
+# -- Options for HTML output ----------------------------------------------
+
+html_theme = "classic"
+
+# Example configuration for intersphinx: refer to the Python standard library.
+intersphinx_mapping = {
+    "numpy": ("http://docs.scipy.org/doc/numpy/", None),
+    "python": ("https://docs.python.org/", None),
+    "torch": ("https://pytorch.org/docs/master/", None),
+}

+ 31 - 0
docs/criterions.rst

@@ -0,0 +1,31 @@
+.. role:: hidden
+    :class: hidden-section
+
+.. _Criterions:
+
+Criterions
+==========
+
+Criterions compute the loss function given the model and batch, roughly::
+
+  loss = criterion(model, batch)
+
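+As a rough sketch (not an exact copy of any bundled criterion; the registration
+API and ``forward()`` signature may differ slightly across fairseq versions), a
+custom criterion can be registered and implemented along these lines::
+
+  import torch.nn.functional as F
+
+  from fairseq.criterions import FairseqCriterion, register_criterion
+
+  @register_criterion('my_cross_entropy')
+  class MyCrossEntropyCriterion(FairseqCriterion):
+
+      def forward(self, model, sample, reduce=True):
+          """Compute the loss for the given sample.
+
+          Returns a tuple of (loss, sample_size, logging_output).
+          """
+          net_output = model(**sample['net_input'])
+          lprobs = model.get_normalized_probs(net_output, log_probs=True)
+          target = model.get_targets(sample, net_output)
+          loss = F.nll_loss(
+              lprobs.view(-1, lprobs.size(-1)),
+              target.view(-1),
+              ignore_index=self.padding_idx,
+              reduction='sum' if reduce else 'none',
+          )
+          sample_size = sample['ntokens']
+          logging_output = {'loss': loss.data, 'ntokens': sample['ntokens']}
+          return loss, sample_size, logging_output
+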
+.. automodule:: fairseq.criterions
+    :members:
+
+.. autoclass:: fairseq.criterions.FairseqCriterion
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.criterions.adaptive_loss.AdaptiveLoss
+    :members:
+    :undoc-members:
+.. autoclass:: fairseq.criterions.composite_loss.CompositeLoss
+    :members:
+    :undoc-members:
+.. autoclass:: fairseq.criterions.cross_entropy.CrossEntropyCriterion
+    :members:
+    :undoc-members:
+.. autoclass:: fairseq.criterions.label_smoothed_cross_entropy.LabelSmoothedCrossEntropyCriterion
+    :members:
+    :undoc-members:

+ 58 - 0
docs/data.rst

@@ -0,0 +1,58 @@
+.. role:: hidden
+    :class: hidden-section
+
+.. module:: fairseq.data
+
+Data Loading and Utilities
+==========================
+
+.. _datasets:
+
+Datasets
+--------
+
+**Datasets** define the data format and provide helpers for creating
+mini-batches.
+
+.. autoclass:: fairseq.data.FairseqDataset
+    :members:
+.. autoclass:: fairseq.data.LanguagePairDataset
+    :members:
+.. autoclass:: fairseq.data.MonolingualDataset
+    :members:
+
+**Helper Datasets**
+
+These datasets wrap other :class:`fairseq.data.FairseqDataset` instances and
+provide additional functionality:
+
+.. autoclass:: fairseq.data.BacktranslationDataset
+    :members:
+.. autoclass:: fairseq.data.ConcatDataset
+    :members:
+.. autoclass:: fairseq.data.ResamplingDataset
+    :members:
+.. autoclass:: fairseq.data.RoundRobinZipDatasets
+    :members:
+.. autoclass:: fairseq.data.TransformEosDataset
+    :members:
+
+
+Dictionary
+----------
+
+.. autoclass:: fairseq.data.Dictionary
+    :members:
+
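+A typical usage pattern (the path below is hypothetical) is to load a
+dictionary written by :ref:`fairseq-preprocess` and use it to map text to
+tensors of token indices::
+
+  from fairseq.data import Dictionary
+
+  # Load a dictionary produced by fairseq-preprocess.
+  vocab = Dictionary.load('data-bin/dict.en.txt')
+  print('{} types'.format(len(vocab)))
+
+  # Encode a pre-tokenized sentence into a tensor of token indices.
+  tokens = vocab.encode_line('hello world', add_if_not_exist=False)
+
+  # Convert the indices back into a string.
+  print(vocab.string(tokens))
+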
+
+Iterators
+---------
+
+.. autoclass:: fairseq.data.CountingIterator
+    :members:
+.. autoclass:: fairseq.data.EpochBatchIterator
+    :members:
+.. autoclass:: fairseq.data.GroupedIterator
+    :members:
+.. autoclass:: fairseq.data.ShardedIterator
+    :members:

+ 2 - 0
docs/docutils.conf

@@ -0,0 +1,2 @@
+[writers]
+option-limit=0

BIN
docs/fairseq.gif


BIN
docs/fairseq_logo.png


+ 216 - 0
docs/getting_started.rst

@@ -0,0 +1,216 @@
+Evaluating Pre-trained Models
+=============================
+
+First, download a pre-trained model along with its vocabularies:
+
+.. code-block:: console
+
+    > curl https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2 | tar xvjf -
+
+This model uses a `Byte Pair Encoding (BPE)
+vocabulary <https://arxiv.org/abs/1508.07909>`__, so we'll have to apply
+the encoding to the source text before it can be translated. This can be
+done with the
+`apply\_bpe.py <https://github.com/rsennrich/subword-nmt/blob/master/subword_nmt/apply_bpe.py>`__
+script using the ``wmt14.en-fr.fconv-py/bpecodes`` file. ``@@`` is
+used as a continuation marker and the original text can be easily
+recovered with e.g. ``sed s/@@ //g`` or by passing the ``--remove-bpe``
+flag to :ref:`fairseq-generate`. Prior to BPE, input text needs to be tokenized
+using ``tokenizer.perl`` from
+`mosesdecoder <https://github.com/moses-smt/mosesdecoder>`__.
+
+Let's use :ref:`fairseq-interactive` to generate translations interactively.
+Here, we use a beam size of 5 and preprocess the input with the Moses
+tokenizer and the given Byte-Pair Encoding vocabulary. It will automatically
+remove the BPE continuation markers and detokenize the output.
+
+.. code-block:: console
+
+    > MODEL_DIR=wmt14.en-fr.fconv-py
+    > fairseq-interactive \
+        --path $MODEL_DIR/model.pt $MODEL_DIR \
+        --beam 5 --source-lang en --target-lang fr \
+        --tokenizer moses \
+        --bpe subword_nmt --bpe-codes $MODEL_DIR/bpecodes
+    | loading model(s) from wmt14.en-fr.fconv-py/model.pt
+    | [en] dictionary: 44206 types
+    | [fr] dictionary: 44463 types
+    | Type the input sentence and press return:
+    Why is it rare to discover new marine mammal species?
+    S-0     Why is it rare to discover new marine mam@@ mal species ?
+    H-0     -0.0643349438905716     Pourquoi est-il rare de découvrir de nouvelles espèces de mammifères marins?
+    P-0     -0.0763 -0.1849 -0.0956 -0.0946 -0.0735 -0.1150 -0.1301 -0.0042 -0.0321 -0.0171 -0.0052 -0.0062 -0.0015
+
+This generation script produces three types of outputs: a line prefixed
+with *S* shows the source sentence after pre-processing; *H* is the
+hypothesis along with an average log-likelihood; and *P* is the
+positional score per token position, including the
+end-of-sentence marker which is omitted from the text.
+
+Other types of output lines you might see are *D*, the detokenized hypothesis;
+*T*, the reference target; *A*, alignment info; and *E*, the history of generation steps.
+
+See the `README <https://github.com/pytorch/fairseq#pre-trained-models>`__ for a
+full list of pre-trained models available.
+
+Training a New Model
+====================
+
+The following tutorial is for machine translation. For an example of how
+to use Fairseq for other tasks, such as :ref:`language modeling`, please see the
+``examples/`` directory.
+
+Data Pre-processing
+-------------------
+
+Fairseq contains example pre-processing scripts for several translation
+datasets: IWSLT 2014 (German-English), WMT 2014 (English-French) and WMT
+2014 (English-German). To pre-process and binarize the IWSLT dataset:
+
+.. code-block:: console
+
+    > cd examples/translation/
+    > bash prepare-iwslt14.sh
+    > cd ../..
+    > TEXT=examples/translation/iwslt14.tokenized.de-en
+    > fairseq-preprocess --source-lang de --target-lang en \
+        --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
+        --destdir data-bin/iwslt14.tokenized.de-en
+
+This will write binarized data that can be used for model training to
+``data-bin/iwslt14.tokenized.de-en``.
+
+Training
+--------
+
+Use :ref:`fairseq-train` to train a new model. Here are a few example settings that work
+well for the IWSLT 2014 dataset:
+
+.. code-block:: console
+
+    > mkdir -p checkpoints/fconv
+    > CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt14.tokenized.de-en \
+        --optimizer nag --lr 0.25 --clip-norm 0.1 --dropout 0.2 --max-tokens 4000 \
+        --arch fconv_iwslt_de_en --save-dir checkpoints/fconv
+
+By default, :ref:`fairseq-train` will use all available GPUs on your machine. Use the
+``CUDA_VISIBLE_DEVICES`` environment variable to select specific GPUs and/or to
+change the number of GPU devices that will be used.
+
+Also note that the batch size is specified in terms of the maximum
+number of tokens per batch (``--max-tokens``). You may need to use a
+smaller value depending on the available GPU memory on your system.
+
+Generation
+----------
+
+Once your model is trained, you can generate translations using
+:ref:`fairseq-generate` **(for binarized data)** or
+:ref:`fairseq-interactive` **(for raw text)**:
+
+.. code-block:: console
+
+    > fairseq-generate data-bin/iwslt14.tokenized.de-en \
+        --path checkpoints/fconv/checkpoint_best.pt \
+        --batch-size 128 --beam 5
+    | [de] dictionary: 35475 types
+    | [en] dictionary: 24739 types
+    | data-bin/iwslt14.tokenized.de-en test 6750 examples
+    | model fconv
+    | loaded checkpoint checkpoints/fconv/checkpoint_best.pt
+    S-721   danke .
+    T-721   thank you .
+    ...
+
+To generate translations with only a CPU, use the ``--cpu`` flag. BPE
+continuation markers can be removed with the ``--remove-bpe`` flag.
+
+Advanced Training Options
+=========================
+
+Large mini-batch training with delayed updates
+----------------------------------------------
+
+The ``--update-freq`` option can be used to accumulate gradients from
+multiple mini-batches and delay updating, creating a larger effective
+batch size. Delayed updates can also improve training speed by reducing
+inter-GPU communication costs and by saving idle time caused by variance
+in workload across GPUs. See `Ott et al.
+(2018) <https://arxiv.org/abs/1806.00187>`__ for more details.
+
+To train on a single GPU with an effective batch size that is equivalent
+to training on 8 GPUs:
+
+.. code-block:: console
+
+    > CUDA_VISIBLE_DEVICES=0 fairseq-train --update-freq 8 (...)
+
+Training with half precision floating point (FP16)
+--------------------------------------------------
+
+.. note::
+
+    FP16 training requires a Volta GPU and CUDA 9.1 or greater
+
+Recent GPUs enable efficient half precision floating point computation,
+e.g., using `Nvidia Tensor Cores
+<https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html>`__.
+Fairseq supports FP16 training with the ``--fp16`` flag:
+
+.. code-block:: console
+
+    > fairseq-train --fp16 (...)
+
+Distributed training
+--------------------
+
+Distributed training in fairseq is implemented on top of ``torch.distributed``.
+The easiest way to launch jobs is with the `torch.distributed.launch
+<https://pytorch.org/docs/stable/distributed.html#launch-utility>`__ tool.
+
+For example, to train a large English-German Transformer model on 2 nodes each
+with 8 GPUs (in total 16 GPUs), run the following command on each node,
+replacing ``node_rank=0`` with ``node_rank=1`` on the second node and making
+sure to update ``--master_addr`` to the IP address of the first node:
+
+.. code-block:: console
+
+    > python -m torch.distributed.launch --nproc_per_node=8 \
+        --nnodes=2 --node_rank=0 --master_addr="192.168.1.1" \
+        --master_port=12345 \
+        $(which fairseq-train) data-bin/wmt16_en_de_bpe32k \
+        --arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \
+        --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
+        --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \
+        --lr 0.0005 \
+        --dropout 0.3 --weight-decay 0.0 --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
+        --max-tokens 3584 \
+        --max-epoch 70 \
+        --fp16
+
+On SLURM clusters, fairseq will automatically detect the number of nodes and
+GPUs, but a port number must be provided:
+
+.. code-block:: console
+
+    > salloc --gpus=16 --nodes 2 (...)
+    > srun fairseq-train --distributed-port 12345 (...)
+
+Sharding very large datasets
+----------------------------
+
+It can be challenging to train over very large datasets, particularly if your
+machine does not have much system RAM. Most tasks in fairseq support training
+over "sharded" datasets, in which the original dataset has been preprocessed
+into non-overlapping chunks (or "shards").
+
+For example, instead of preprocessing all your data into a single "data-bin"
+directory, you can split the data and create "data-bin1", "data-bin2", etc.
+Then you can adapt your training command like so:
+
+.. code-block:: console
+
+    > fairseq-train data-bin1:data-bin2:data-bin3 (...)
+
+Training will now iterate over each shard, one by one, with each shard
+corresponding to an "epoch", thus reducing system memory usage.

+ 284 - 0
docs/hydra_integration.md

@@ -0,0 +1,284 @@
+## Hydra
+
+[Hydra](https://github.com/facebookresearch/hydra) is an open-source Python
+framework that simplifies the development of research and other complex
+applications. The key feature is the ability to dynamically create a
+hierarchical configuration by composition and override it through config files
+and the command line. The name Hydra comes from its ability to run multiple
+similar jobs - much like a Hydra with multiple heads.
+
+## Motivation
+
+Until recently, all components in fairseq were configured through a shared
+`args` namespace that was created at application startup. Components declared
+their own `add_args` method to update the argparse parser, hoping that the names
+would not clash with arguments from other components. While this model works for
+smaller applications, as fairseq grew and became integrated into other
+applications, this became problematic. In order to determine how to configure
+each component, one needed to a) examine what args were added by this component,
+and b) read the code to figure out what shared arguments it was using that were
+added in other places. Reproducing models involved sharing commands that often
+contained dozens of command line switches.
+
+The model described above is still supported by fairseq for backward
+compatibility, but will be deprecated some time in the future.
+
+New components in fairseq should now create a dataclass that encapsulates all
+parameters required to configure this component. The dataclass is registered
+along with the component, and fairseq takes care of constructing and providing
+this configuration object to the component's constructor. Note that sharing
+parameters can optionally still work, but one has to explicitly point to the
+"source of truth" (see inheritance example below). These changes make components
+in fairseq more independent and re-usable by other applications: all that is
+needed to create a component is to initialize its dataclass and overwrite some
+of the defaults.
+
+While configuring fairseq through command line (using either the legacy argparse
+based or the new Hydra based entry points) is still fully supported, you can now
+take advantage of configuring fairseq completely or piece-by-piece through
+hierarchical YAML configuration files. These files can also be shipped as
+examples that others can use to run an identically configured job.
+
+Additionally, Hydra has a rich and growing [library of
+plugins](https://github.com/facebookresearch/hydra/tree/master/plugins) that
+provide functionality such as hyperparameter sweeping (including using bayesian
+optimization through the [Ax](https://github.com/facebook/Ax) library), job
+launching across various platforms, and more.
+
+## Creating or migrating components
+
+In general, each new (or updated) component should provide a companion
+[dataclass](https://www.python.org/dev/peps/pep-0557/). These dataclasses are
+typically located in the same file as the component and are passed as arguments
+to the `register_*()` functions. Top-level configs that should be present in
+every fairseq application are placed in the
+[global](fairseq/dataclass/configs.py) config file and added to the
+`FairseqConfig` object.
+
+Each dataclass is a plain-old-data object, similar to a `NamedTuple`. These
+classes are decorated with a `@dataclass` decorator, and typically inherit from
+`FairseqDataclass` (which adds some functionality for backward compatibility).
+Each field must have a type, and generally has metadata (such as a help string)
+and a default value. Only primitive types or other config objects are allowed as
+data types for each field.
+
+#### Example:
+
+```python
+from dataclasses import dataclass, field
+from fairseq.dataclass import FairseqDataclass
+
+@dataclass
+class InteractiveConfig(FairseqDataclass):
+    buffer_size: int = field(
+        default=0,
+        metadata={
+            "help": "read this many sentences into a buffer before processing them"
+        },
+    )
+    input: str = field(
+        default="-",
+        metadata={"help": "file to read from; use - for stdin"},
+    )
+```
+
+### Inheriting values
+
+Some components require sharing a value. For example, a learning rate scheduler
+and an optimizer may both need to know the initial learning rate value. One can
+declare a field that, by default, will inherit its value from another config
+node in the same hierarchy:
+
+```python
+from typing import List
+
+from omegaconf import II
+
+
+@dataclass
+class FairseqAdamConfig(FairseqDataclass):
+    ...
+    lr: List[float] = II("optimization.lr")
+    ...
+```
+
+`II("optimization.lr")` is syntactic sugar for `"${optimization.lr}"`, which is
+the value one can use in a YAML config file or through command line to achieve
+the same effect. Note that this assumes that there is an "optimization" config
+object in the root config and it has a field called "lr".
+
+### Tasks and Models
+
+Creating Tasks and Models works the same as before, except that legacy
+implementations now inherit from `LegacyFairseq*` base classes, while new
+components inherit from `FairseqTask` and `FairseqModel` and provide a dataclass
+to the `register_*()` functions.
+
+#### Task example:
+
+```python
+@dataclass
+class LanguageModelingConfig(FairseqDataclass):
+    data: Optional[str] = field(
+        default=None, metadata={"help": "path to data directory"}
+    )
+    ...
+
+@register_task("language_modeling", dataclass=LanguageModelingConfig)
+class LanguageModelingTask(FairseqTask):
+    ...
+    @classmethod
+    def setup_task(cls, cfg: LanguageModelingConfig):
+        ...
+```
+
+#### Model example:
+
+```python
+@dataclass
+class TransformerLanguageModelConfig(FairseqDataclass):
+    activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
+        default="relu", metadata={"help": "activation function to use"}
+    )
+    dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
+    ...
+
+@register_model("transformer_lm", dataclass=TransformerLanguageModelConfig)
+class TransformerLanguageModel(FairseqLanguageModel):
+    ...
+    @classmethod
+    def build_model(cls, cfg: TransformerLanguageModelConfig, task: FairseqTask):
+        ...
+```
+
+### Other components
+
+Other components work as before, but they now take their configuration dataclass
+as the only constructor argument:
+
+```python
+@dataclass
+class MosesTokenizerConfig(FairseqDataclass):
+    source_lang: str = field(default="en", metadata={"help": "source language"})
+    ...
+
+@register_tokenizer("moses", dataclass=MosesTokenizerConfig)
+class MosesTokenizer(object):
+    def __init__(self, cfg: MosesTokenizerConfig):
+        ...
+```
+
+Note that if you are adding a new registry for a new set of components, you need
+to add it to the `FairseqConfig` object in `fairseq/dataclass/configs.py`:
+
+```python
+@dataclass
+class FairseqConfig(object):
+    ...
+    my_new_registry: Any = None
+```
+
+## Training with `fairseq-hydra-train`
+
+To fully take advantage of configuration flexibility offered by Hydra, you may
+want to train new models using the `fairseq-hydra-train` entry point. Legacy CLI
+tools such as `fairseq-train` will remain supported for the foreseeable future
+but will be deprecated eventually.
+
+On startup, Hydra will create a configuration object that contains a hierarchy
+of all the necessary dataclasses populated with their default values in the
+code. The default values are overwritten by values found in YAML files in
+the `fairseq/config` directory (which currently sets minimal defaults) and then
+further overwritten by values provided through command line arguments.
+
+Some of the most common use cases are shown below:
+
+### 1. Override default values through command line:
+
+```shell script
+$ fairseq-hydra-train \
+    distributed_training.distributed_world_size=1 \
+    dataset.batch_size=2 \
+    task.data=data-bin \
+    model=transformer_lm/transformer_lm_gpt \
+    task=language_modeling \
+    optimization.max_update=5000
+```
+
+Note that along with explicitly providing values for parameters such as
+`dataset.batch_size`, this also tells Hydra to overlay configuration found in
+`fairseq/config/model/transformer_lm/transformer_lm_gpt.yaml` over the default
+values in the dataclass. If you want to train a model without specifying a
+particular architecture you can simply specify `model=transformer_lm`. This only
+works for migrated tasks and models.
+
+### 2. Replace bundled configs with an external config:
+
+```shell script
+$ fairseq-hydra-train \
+    --config-dir /path/to/external/configs \
+    --config-name wiki103
+```
+
+where `/path/to/external/configs/wiki103.yaml` contains:
+
+```yaml
+# @package _group_
+
+model:
+  _name: transformer_lm
+distributed_training:
+  distributed_world_size: 1
+dataset:
+  batch_size: 2
+task:
+  _name: language_modeling
+  data: /path/to/data
+  add_bos_token: false
+  max_target_positions: 1024
+optimization:
+  max_update: 50000
+  lr: [ 0.25 ]
+criterion: cross_entropy
+optimizer: adam
+lr_scheduler:
+  _name: cosine
+```
+
+Note that the bundled configs from the `fairseq/config` directory are not used
+here; however, the defaults from each dataclass will still be used (unless
+overwritten by your external config).
+
+Additionally, you can choose to break up your configs by creating a directory
+structure in the same location as your main config file, with the names of the
+top-level fields (such as "model", "dataset", etc), and placing config files
+with meaningful names that would populate that specific section of your
+top-level config file (for example, you might have
+`model/small_transformer_lm.yaml`, `model/big_transformer_lm.yaml`, etc). You
+can then specify the correct configuration via command line, defaults in the
+main config, or even launch all of them as a sweep (see Hydra documentation on
+how to do this).
+
+### 3. Add an external config directory to Hydra search path:
+
+This allows combining default configuration (including using any bundled config
+files), while specifying your own config files for some parts of the
+configuration.
+
+```shell script
+$ fairseq-hydra-train \
+    distributed_training.distributed_world_size=1 \
+    dataset.batch_size=2 \
+    task.data=/path/to/data/ \
+    model=transformer_lm/2_layers \
+    task=language_modeling \
+    optimization.max_update=5000 \
+    --config-dir /path/to/external/configs
+```
+
+where `/path/to/external/configs` has the following structure:
+```
+.
++-- model
+|   +-- transformer_lm
+|   |   +-- 2_layers.yaml
+```
+
+and `2_layers.yaml` contains a copy of `transformer_lm_gpt.yaml` but with
+`decoder_layers` set to 2. You can add other configs to configure other
+components as well.

+ 49 - 0
docs/index.rst

@@ -0,0 +1,49 @@
+.. fairseq documentation master file, created by
+   sphinx-quickstart on Fri Aug 17 21:45:30 2018.
+   You can adapt this file completely to your liking, but it should at least
+   contain the root `toctree` directive.
+
+:github_url: https://github.com/pytorch/fairseq
+
+
+fairseq documentation
+=====================
+
+Fairseq is a sequence modeling toolkit written in `PyTorch
+<http://pytorch.org/>`_ that allows researchers and developers to
+train custom models for translation, summarization, language modeling and other
+text generation tasks.
+
+.. toctree::
+    :maxdepth: 1
+    :caption: Getting Started
+
+    getting_started
+    command_line_tools
+
+.. toctree::
+    :maxdepth: 1
+    :caption: Extending Fairseq
+
+    overview
+    tutorial_simple_lstm
+    tutorial_classifying_names
+
+.. toctree::
+    :maxdepth: 2
+    :caption: Library Reference
+
+    tasks
+    models
+    criterions
+    optim
+    lr_scheduler
+    data
+    modules
+
+
+Indices and tables
+==================
+
+* :ref:`genindex`
+* :ref:`search`

+ 34 - 0
docs/lr_scheduler.rst

@@ -0,0 +1,34 @@
+.. role:: hidden
+    :class: hidden-section
+
+.. _Learning Rate Schedulers:
+
+Learning Rate Schedulers
+========================
+
+Learning Rate Schedulers update the learning rate over the course of training.
+Learning rates can be updated after each update via :func:`step_update` or at
+epoch boundaries via :func:`step`.
+
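+As a rough illustration (not one of the bundled schedulers), a scheduler that
+decays the learning rate by a fixed factor at every epoch boundary could be
+sketched as follows; the exact constructor signature (``cfg`` vs. ``args``) and
+registration API depend on the fairseq version::
+
+  from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler
+
+  @register_lr_scheduler('decay_per_epoch')
+  class DecayPerEpochSchedule(FairseqLRScheduler):
+
+      def __init__(self, cfg, optimizer):
+          super().__init__(cfg, optimizer)
+          self.lr = cfg.lr[0]
+
+      def step(self, epoch, val_loss=None):
+          """Called at epoch boundaries."""
+          super().step(epoch, val_loss)
+          self.lr *= 0.9  # hypothetical decay factor
+          self.optimizer.set_lr(self.lr)
+          return self.optimizer.get_lr()
+
+      def step_update(self, num_updates):
+          """Called after each optimizer update; no per-update change here."""
+          return self.optimizer.get_lr()
+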
+.. automodule:: fairseq.optim.lr_scheduler
+    :members:
+
+.. autoclass:: fairseq.optim.lr_scheduler.FairseqLRScheduler
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.optim.lr_scheduler.cosine_lr_scheduler.CosineSchedule
+    :members:
+    :undoc-members:
+.. autoclass:: fairseq.optim.lr_scheduler.fixed_schedule.FixedSchedule
+    :members:
+    :undoc-members:
+.. autoclass:: fairseq.optim.lr_scheduler.inverse_square_root_schedule.InverseSquareRootSchedule
+    :members:
+    :undoc-members:
+.. autoclass:: fairseq.optim.lr_scheduler.reduce_lr_on_plateau.ReduceLROnPlateau
+    :members:
+    :undoc-members:
+.. autoclass:: fairseq.optim.lr_scheduler.triangular_lr_scheduler.TriangularSchedule
+    :members:
+    :undoc-members:

+ 36 - 0
docs/make.bat

@@ -0,0 +1,36 @@
+@ECHO OFF
+
+pushd %~dp0
+
+REM Command file for Sphinx documentation
+
+if "%SPHINXBUILD%" == "" (
+	set SPHINXBUILD=python -msphinx
+)
+set SOURCEDIR=.
+set BUILDDIR=_build
+set SPHINXPROJ=fairseq
+
+if "%1" == "" goto help
+
+%SPHINXBUILD% >NUL 2>NUL
+if errorlevel 9009 (
+	echo.
+	echo.The Sphinx module was not found. Make sure you have Sphinx installed,
+	echo.then set the SPHINXBUILD environment variable to point to the full
+	echo.path of the 'sphinx-build' executable. Alternatively you may add the
+	echo.Sphinx directory to PATH.
+	echo.
+	echo.If you don't have Sphinx installed, grab it from
+	echo.http://sphinx-doc.org/
+	exit /b 1
+)
+
+%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
+goto end
+
+:help
+%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
+
+:end
+popd

+ 104 - 0
docs/models.rst

@@ -0,0 +1,104 @@
+.. role:: hidden
+    :class: hidden-section
+
+.. module:: fairseq.models
+
+.. _Models:
+
+Models
+======
+
+A Model defines the neural network's ``forward()`` method and encapsulates all
+of the learnable parameters in the network. Each model also provides a set of
+named *architectures* that define the precise network configuration (e.g.,
+embedding dimension, number of layers, etc.).
+
+Both the model type and architecture are selected via the ``--arch``
+command-line argument. Once selected, a model may expose additional command-line
+arguments for further configuration.
+
+.. note::
+
+    All fairseq Models extend :class:`BaseFairseqModel`, which in turn extends
+    :class:`torch.nn.Module`. Thus any fairseq Model can be used as a
+    stand-alone Module in other PyTorch code.
+
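+For example, a trained model can be loaded from a checkpoint and used as an
+ordinary module (the checkpoint path and the ``batch`` variable below are
+hypothetical)::
+
+  from fairseq import checkpoint_utils
+
+  # Load an ensemble of one model from a saved checkpoint.
+  models, _cfg = checkpoint_utils.load_model_ensemble(['checkpoints/checkpoint_best.pt'])
+  model = models[0]  # a BaseFairseqModel, i.e. also a torch.nn.Module
+  model.eval()       # standard nn.Module methods work as usual
+
+  # The arguments expected by ``forward()`` are determined by the Task that
+  # built the mini-batch, typically the contents of ``batch['net_input']``:
+  # outputs = model(**batch['net_input'])
+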
+
+Convolutional Neural Networks (CNN)
+-----------------------------------
+
+.. module:: fairseq.models.fconv
+.. autoclass:: fairseq.models.fconv.FConvModel
+    :members:
+.. autoclass:: fairseq.models.fconv.FConvEncoder
+    :members:
+    :undoc-members:
+.. autoclass:: fairseq.models.fconv.FConvDecoder
+    :members:
+
+
+Long Short-Term Memory (LSTM) networks
+--------------------------------------
+
+.. module:: fairseq.models.lstm
+.. autoclass:: fairseq.models.lstm.LSTMModel
+    :members:
+.. autoclass:: fairseq.models.lstm.LSTMEncoder
+    :members:
+.. autoclass:: fairseq.models.lstm.LSTMDecoder
+    :members:
+
+
+Transformer (self-attention) networks
+-------------------------------------
+
+.. module:: fairseq.models.transformer
+.. autoclass:: fairseq.models.transformer.TransformerModel
+    :members:
+.. autoclass:: fairseq.models.transformer.TransformerEncoder
+    :members:
+.. autoclass:: fairseq.models.transformer.TransformerEncoderLayer
+    :members:
+.. autoclass:: fairseq.models.transformer.TransformerDecoder
+    :members:
+.. autoclass:: fairseq.models.transformer.TransformerDecoderLayer
+    :members:
+
+
+Adding new models
+-----------------
+
+.. currentmodule:: fairseq.models
+.. autofunction:: fairseq.models.register_model
+.. autofunction:: fairseq.models.register_model_architecture
+.. autoclass:: fairseq.models.BaseFairseqModel
+    :members:
+    :undoc-members:
+.. autoclass:: fairseq.models.FairseqEncoderDecoderModel
+    :members:
+    :undoc-members:
+.. autoclass:: fairseq.models.FairseqEncoderModel
+    :members:
+    :undoc-members:
+.. autoclass:: fairseq.models.FairseqLanguageModel
+    :members:
+    :undoc-members:
+.. autoclass:: fairseq.models.FairseqMultiModel
+    :members:
+    :undoc-members:
+.. autoclass:: fairseq.models.FairseqEncoder
+    :members:
+.. autoclass:: fairseq.models.CompositeEncoder
+    :members:
+.. autoclass:: fairseq.models.FairseqDecoder
+    :members:
+
+
+.. _Incremental decoding:
+
+Incremental decoding
+--------------------
+
+.. autoclass:: fairseq.models.FairseqIncrementalDecoder
+    :members:
+    :undoc-members:

+ 9 - 0
docs/modules.rst

@@ -0,0 +1,9 @@
+Modules
+=======
+
+Fairseq provides several stand-alone :class:`torch.nn.Module` classes that may
+be helpful when implementing a new :class:`~fairseq.models.BaseFairseqModel`.
+
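+For example, a custom model might reuse fairseq's multi-head attention instead
+of re-implementing it. The sketch below assumes that
+:class:`fairseq.modules.MultiheadAttention` takes the embedding dimension and
+number of heads in its constructor and expects inputs shaped
+``(seq_len, batch, embed_dim)``; check the class documentation below for the
+exact interface in your version::
+
+  import torch
+  from fairseq.modules import MultiheadAttention
+
+  attn = MultiheadAttention(embed_dim=512, num_heads=8, dropout=0.1)
+
+  x = torch.randn(10, 2, 512)  # (seq_len, batch, embed_dim)
+  out, attn_weights = attn(query=x, key=x, value=x)
+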
+.. automodule:: fairseq.modules
+    :members:
+    :undoc-members:

+ 38 - 0
docs/optim.rst

@@ -0,0 +1,38 @@
+.. role:: hidden
+    :class: hidden-section
+
+.. _optimizers:
+
+Optimizers
+==========
+
+Optimizers update the Model parameters based on the gradients.
+
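+In the high-level training flow (see the Overview), an optimizer is driven
+roughly as follows::
+
+  loss = criterion(model, batch)
+  optimizer.backward(loss)   # compute gradients
+  optimizer.step()           # update the model parameters
+  optimizer.zero_grad()      # reset gradients before the next update
+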
+.. automodule:: fairseq.optim
+    :members:
+
+.. autoclass:: fairseq.optim.FairseqOptimizer
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.optim.adadelta.Adadelta
+    :members:
+    :undoc-members:
+.. autoclass:: fairseq.optim.adagrad.Adagrad
+    :members:
+    :undoc-members:
+.. autoclass:: fairseq.optim.adafactor.FairseqAdafactor
+    :members:
+    :undoc-members:
+.. autoclass:: fairseq.optim.adam.FairseqAdam
+    :members:
+    :undoc-members:
+.. autoclass:: fairseq.optim.fp16_optimizer.FP16Optimizer
+    :members:
+    :undoc-members:
+.. autoclass:: fairseq.optim.nag.FairseqNAG
+    :members:
+    :undoc-members:
+.. autoclass:: fairseq.optim.sgd.SGD
+    :members:
+    :undoc-members:

+ 74 - 0
docs/overview.rst

@@ -0,0 +1,74 @@
+Overview
+========
+
+Fairseq can be extended through user-supplied `plug-ins
+<https://en.wikipedia.org/wiki/Plug-in_(computing)>`_. We support five kinds of
+plug-ins:
+
+- :ref:`Models` define the neural network architecture and encapsulate all of the
+  learnable parameters.
+- :ref:`Criterions` compute the loss function given the model outputs and targets.
+- :ref:`Tasks` store dictionaries and provide helpers for loading/iterating over
+  Datasets, initializing the Model/Criterion and calculating the loss.
+- :ref:`Optimizers` update the Model parameters based on the gradients.
+- :ref:`Learning Rate Schedulers` update the learning rate over the course of
+  training.
+
+**Training Flow**
+
+Given a ``model``, ``criterion``, ``task``, ``optimizer`` and ``lr_scheduler``,
+fairseq implements the following high-level training flow::
+
+  for epoch in range(num_epochs):
+      itr = task.get_batch_iterator(task.dataset('train'))
+      for num_updates, batch in enumerate(itr):
+          task.train_step(batch, model, criterion, optimizer)
+          average_and_clip_gradients()
+          optimizer.step()
+          lr_scheduler.step_update(num_updates)
+      lr_scheduler.step(epoch)
+
+where the default implementation for ``task.train_step`` is roughly::
+
+  def train_step(self, batch, model, criterion, optimizer, **unused):
+      loss = criterion(model, batch)
+      optimizer.backward(loss)
+      return loss
+
+**Registering new plug-ins**
+
+New plug-ins are *registered* through a set of ``@register`` function
+decorators, for example::
+
+  @register_model('my_lstm')
+  class MyLSTM(FairseqEncoderDecoderModel):
+      (...)
+
+Once registered, new plug-ins can be used with the existing :ref:`Command-line
+Tools`. See the Tutorial sections for more detailed walkthroughs of how to add
+new plug-ins.
+
+**Loading plug-ins from another directory**
+
+New plug-ins can be defined in a custom module stored on the user's system. To
+import the module and make the plug-in available to *fairseq*, the command line
+supports the ``--user-dir`` flag, which specifies a custom location for
+additional modules to load into *fairseq*.
+
+For example, assuming this directory tree::
+
+  /home/user/my-module/
+  └── __init__.py
+  
+with ``__init__.py``::
+
+  from fairseq.models import register_model_architecture
+  from fairseq.models.transformer import transformer_vaswani_wmt_en_de_big
+
+  @register_model_architecture('transformer', 'my_transformer')
+  def transformer_mmt_big(args):
+      transformer_vaswani_wmt_en_de_big(args)
+
+it is possible to invoke the :ref:`fairseq-train` script with the new architecture with::
+
+  fairseq-train ... --user-dir /home/user/my-module -a my_transformer --task translation

+ 61 - 0
docs/tasks.rst

@@ -0,0 +1,61 @@
+.. role:: hidden
+    :class: hidden-section
+
+.. module:: fairseq.tasks
+
+.. _Tasks:
+
+Tasks
+=====
+
+Tasks store dictionaries and provide helpers for loading/iterating over
+Datasets, initializing the Model/Criterion and calculating the loss.
+
+Tasks can be selected via the ``--task`` command-line argument. Once selected, a
+task may expose additional command-line arguments for further configuration.
+
+Example usage::
+
+    # setup the task (e.g., load dictionaries)
+    task = fairseq.tasks.setup_task(args)
+
+    # build model and criterion
+    model = task.build_model(args)
+    criterion = task.build_criterion(args)
+
+    # load datasets
+    task.load_dataset('train')
+    task.load_dataset('valid')
+
+    # iterate over mini-batches of data
+    batch_itr = task.get_batch_iterator(
+        task.dataset('train'), max_tokens=4096,
+    )
+    for batch in batch_itr:
+        # compute the loss
+        loss, sample_size, logging_output = task.get_loss(
+            model, criterion, batch,
+        )
+        loss.backward()
+
+
+Translation
+-----------
+
+.. autoclass:: fairseq.tasks.translation.TranslationTask
+
+.. _language modeling:
+
+Language Modeling
+-----------------
+
+.. autoclass:: fairseq.tasks.language_modeling.LanguageModelingTask
+
+
+Adding new tasks
+----------------
+
+.. autofunction:: fairseq.tasks.register_task
+.. autoclass:: fairseq.tasks.FairseqTask
+    :members:
+    :undoc-members:

+ 415 - 0
docs/tutorial_classifying_names.rst

@@ -0,0 +1,415 @@
+Tutorial: Classifying Names with a Character-Level RNN
+======================================================
+
+In this tutorial we will extend fairseq to support *classification* tasks. In
+particular we will re-implement the PyTorch tutorial for `Classifying Names with
+a Character-Level RNN <https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html>`_
+in fairseq. It is recommended to quickly skim that tutorial before beginning
+this one.
+
+This tutorial covers:
+
+1. **Preprocessing the data** to create dictionaries.
+2. **Registering a new Model** that encodes an input sentence with a simple RNN
+   and predicts the output label.
+3. **Registering a new Task** that loads our dictionaries and dataset.
+4. **Training the Model** using the existing command-line tools.
+5. **Writing an evaluation script** that imports fairseq and allows us to
+   interactively evaluate our model on new inputs.
+
+
+1. Preprocessing the data
+-------------------------
+
+The original tutorial provides raw data, but we'll work with a modified version
+of the data that is already tokenized into characters and split into separate
+train, valid and test sets.
+
+Download and extract the data from here:
+`tutorial_names.tar.gz <https://dl.fbaipublicfiles.com/fairseq/data/tutorial_names.tar.gz>`_
+
+Once extracted, let's preprocess the data using the :ref:`fairseq-preprocess`
+command-line tool to create the dictionaries. While this tool is primarily
+intended for sequence-to-sequence problems, we're able to reuse it here by
+treating the label as a "target" sequence of length 1. We'll also output the
+preprocessed files in "raw" format using the ``--dataset-impl`` option to
+enhance readability:
+
+.. code-block:: console
+
+  > fairseq-preprocess \
+    --trainpref names/train --validpref names/valid --testpref names/test \
+    --source-lang input --target-lang label \
+    --destdir names-bin --dataset-impl raw
+
+After running the above command you should see a new directory,
+:file:`names-bin/`, containing the dictionaries for *inputs* and *labels*.
+
+
+2. Registering a new Model
+--------------------------
+
+Next we'll register a new model in fairseq that will encode an input sentence
+with a simple RNN and predict the output label. Compared to the original PyTorch
+tutorial, our version will also work with batches of data and GPU Tensors.
+
+First let's copy the simple RNN module implemented in the `PyTorch tutorial
+<https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html#creating-the-network>`_.
+Create a new file named :file:`fairseq/models/rnn_classifier.py` with the
+following contents::
+
+    import torch
+    import torch.nn as nn
+
+    class RNN(nn.Module):
+
+        def __init__(self, input_size, hidden_size, output_size):
+            super(RNN, self).__init__()
+
+            self.hidden_size = hidden_size
+
+            self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
+            self.i2o = nn.Linear(input_size + hidden_size, output_size)
+            self.softmax = nn.LogSoftmax(dim=1)
+
+        def forward(self, input, hidden):
+            combined = torch.cat((input, hidden), 1)
+            hidden = self.i2h(combined)
+            output = self.i2o(combined)
+            output = self.softmax(output)
+            return output, hidden
+
+        def initHidden(self):
+            return torch.zeros(1, self.hidden_size)
+
+We must also *register* this model with fairseq using the
+:func:`~fairseq.models.register_model` function decorator. Once the model is
+registered we'll be able to use it with the existing :ref:`Command-line Tools`.
+
+All registered models must implement the :class:`~fairseq.models.BaseFairseqModel`
+interface, so we'll create a small wrapper class in the same file and register
+it in fairseq with the name ``'rnn_classifier'``::
+
+    from fairseq.models import BaseFairseqModel, register_model
+
+    # Note: the register_model "decorator" should immediately precede the
+    # definition of the Model class.
+
+    @register_model('rnn_classifier')
+    class FairseqRNNClassifier(BaseFairseqModel):
+
+        @staticmethod
+        def add_args(parser):
+            # Models can override this method to add new command-line arguments.
+            # Here we'll add a new command-line argument to configure the
+            # dimensionality of the hidden state.
+            parser.add_argument(
+                '--hidden-dim', type=int, metavar='N',
+                help='dimensionality of the hidden state',
+            )
+
+        @classmethod
+        def build_model(cls, args, task):
+            # Fairseq initializes models by calling the ``build_model()``
+            # function. This provides more flexibility, since the returned model
+            # instance can be of a different type than the one that was called.
+            # In this case we'll just return a FairseqRNNClassifier instance.
+
+            # Initialize our RNN module
+            rnn = RNN(
+                # We'll define the Task in the next section, but for now just
+                # notice that the task holds the dictionaries for the "source"
+                # (i.e., the input sentence) and "target" (i.e., the label).
+                input_size=len(task.source_dictionary),
+                hidden_size=args.hidden_dim,
+                output_size=len(task.target_dictionary),
+            )
+
+            # Return the wrapped version of the module
+            return FairseqRNNClassifier(
+                rnn=rnn,
+                input_vocab=task.source_dictionary,
+            )
+
+        def __init__(self, rnn, input_vocab):
+            super(FairseqRNNClassifier, self).__init__()
+
+            self.rnn = rnn
+            self.input_vocab = input_vocab
+
+            # The RNN module in the tutorial expects one-hot inputs, so we can
+            # precompute the identity matrix to help convert from indices to
+            # one-hot vectors. We register it as a buffer so that it is moved to
+            # the GPU when ``cuda()`` is called.
+            self.register_buffer('one_hot_inputs', torch.eye(len(input_vocab)))
+
+        def forward(self, src_tokens, src_lengths):
+            # The inputs to the ``forward()`` function are determined by the
+            # Task, and in particular the ``'net_input'`` key in each
+            # mini-batch. We'll define the Task in the next section, but for
+            # now just know that *src_tokens* has shape `(batch, src_len)` and
+            # *src_lengths* has shape `(batch)`.
+            bsz, max_src_len = src_tokens.size()
+
+            # Initialize the RNN hidden state. Compared to the original PyTorch
+            # tutorial we'll also handle batched inputs and work on the GPU.
+            hidden = self.rnn.initHidden()
+            hidden = hidden.repeat(bsz, 1)  # expand for batched inputs
+            hidden = hidden.to(src_tokens.device)  # move to GPU
+
+            for i in range(max_src_len):
+                # WARNING: The inputs have padding, so we should mask those
+                # elements here so that padding doesn't affect the results.
+                # This is left as an exercise for the reader. The padding symbol
+                # is given by ``self.input_vocab.pad()`` and the unpadded length
+                # of each input is given by *src_lengths*.
+
+                # One-hot encode a batch of input characters.
+                input = self.one_hot_inputs[src_tokens[:, i].long()]
+
+                # Feed the input to our RNN.
+                output, hidden = self.rnn(input, hidden)
+
+            # Return the final output state for making a prediction
+            return output
+
+Finally let's define a *named architecture* with the configuration for our
+model. This is done with the :func:`~fairseq.models.register_model_architecture`
+function decorator. Thereafter this named architecture can be used with the
+``--arch`` command-line argument, e.g., ``--arch pytorch_tutorial_rnn``::
+
+    from fairseq.models import register_model_architecture
+
+    # The first argument to ``register_model_architecture()`` should be the name
+    # of the model we registered above (i.e., 'rnn_classifier'). The function we
+    # register here should take a single argument *args* and modify it in-place
+    # to match the desired architecture.
+
+    @register_model_architecture('rnn_classifier', 'pytorch_tutorial_rnn')
+    def pytorch_tutorial_rnn(args):
+        # We use ``getattr()`` to prioritize arguments that are explicitly given
+        # on the command-line, so that the defaults defined below are only used
+        # when no other value has been specified.
+        args.hidden_dim = getattr(args, 'hidden_dim', 128)
+
+
+3. Registering a new Task
+-------------------------
+
+Now we'll register a new :class:`~fairseq.tasks.FairseqTask` that will load our
+dictionaries and dataset. Tasks can also control how the data is batched into
+mini-batches, but in this tutorial we'll reuse the batching provided by
+:class:`fairseq.data.LanguagePairDataset`.
+
+Create a new file named :file:`fairseq/tasks/simple_classification.py` with the
+following contents::
+
+  import os
+  import torch
+
+  from fairseq.data import Dictionary, LanguagePairDataset
+  from fairseq.tasks import LegacyFairseqTask, register_task
+
+
+  @register_task('simple_classification')
+  class SimpleClassificationTask(LegacyFairseqTask):
+
+      @staticmethod
+      def add_args(parser):
+          # Add some command-line arguments for specifying where the data is
+          # located and the maximum supported input length.
+          parser.add_argument('data', metavar='FILE',
+                              help='file prefix for data')
+          parser.add_argument('--max-positions', default=1024, type=int,
+                              help='max input length')
+
+      @classmethod
+      def setup_task(cls, args, **kwargs):
+          # Here we can perform any setup required for the task. This may include
+          # loading Dictionaries, initializing shared Embedding layers, etc.
+          # In this case we'll just load the Dictionaries.
+          input_vocab = Dictionary.load(os.path.join(args.data, 'dict.input.txt'))
+          label_vocab = Dictionary.load(os.path.join(args.data, 'dict.label.txt'))
+          print('| [input] dictionary: {} types'.format(len(input_vocab)))
+          print('| [label] dictionary: {} types'.format(len(label_vocab)))
+
+          return SimpleClassificationTask(args, input_vocab, label_vocab)
+
+      def __init__(self, args, input_vocab, label_vocab):
+          super().__init__(args)
+          self.input_vocab = input_vocab
+          self.label_vocab = label_vocab
+
+      def load_dataset(self, split, **kwargs):
+          """Load a given dataset split (e.g., train, valid, test)."""
+
+          prefix = os.path.join(self.args.data, '{}.input-label'.format(split))
+
+          # Read input sentences.
+          sentences, lengths = [], []
+          with open(prefix + '.input', encoding='utf-8') as file:
+              for line in file:
+                  sentence = line.strip()
+
+                  # Tokenize the sentence, splitting on spaces
+                  tokens = self.input_vocab.encode_line(
+                      sentence, add_if_not_exist=False,
+                  )
+
+                  sentences.append(tokens)
+                  lengths.append(tokens.numel())
+
+          # Read labels.
+          labels = []
+          with open(prefix + '.label', encoding='utf-8') as file:
+              for line in file:
+                  label = line.strip()
+                  labels.append(
+                      # Convert label to a numeric ID.
+                      torch.LongTensor([self.label_vocab.add_symbol(label)])
+                  )
+
+          assert len(sentences) == len(labels)
+          print('| {} {} {} examples'.format(self.args.data, split, len(sentences)))
+
+          # We reuse LanguagePairDataset since classification can be modeled as a
+          # sequence-to-sequence task where the target sequence has length 1.
+          self.datasets[split] = LanguagePairDataset(
+              src=sentences,
+              src_sizes=lengths,
+              src_dict=self.input_vocab,
+              tgt=labels,
+              tgt_sizes=torch.ones(len(labels)),  # targets have length 1
+              tgt_dict=self.label_vocab,
+              left_pad_source=False,
+              # Since our target is a single class label, there's no need for
+              # teacher forcing. If we set this to ``True`` then our Model's
+              # ``forward()`` method would receive an additional argument called
+              # *prev_output_tokens* that would contain a shifted version of the
+              # target sequence.
+              input_feeding=False,
+          )
+
+      def max_positions(self):
+          """Return the max input length allowed by the task."""
+          # The source should be less than *args.max_positions* and the "target"
+          # has max length 1.
+          return (self.args.max_positions, 1)
+
+      @property
+      def source_dictionary(self):
+          """Return the source :class:`~fairseq.data.Dictionary`."""
+          return self.input_vocab
+
+      @property
+      def target_dictionary(self):
+          """Return the target :class:`~fairseq.data.Dictionary`."""
+          return self.label_vocab
+
+      # We could override this method if we wanted more control over how batches
+      # are constructed, but it's not necessary for this tutorial since we can
+      # reuse the batching provided by LanguagePairDataset.
+      #
+      # def get_batch_iterator(
+      #     self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
+      #     ignore_invalid_inputs=False, required_batch_size_multiple=1,
+      #     seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=1,
+      #     data_buffer_size=0, disable_iterator_cache=False,
+      # ):
+      #     (...)
+
+
+4. Training the Model
+---------------------
+
+Now we're ready to train the model. We can use the existing :ref:`fairseq-train`
+command-line tool for this, making sure to specify our new Task (``--task
+simple_classification``) and Model architecture (``--arch
+pytorch_tutorial_rnn``):
+
+.. note::
+
+  You can also configure the dimensionality of the hidden state by passing the
+  ``--hidden-dim`` argument to :ref:`fairseq-train`.
+
+.. code-block:: console
+
+  > fairseq-train names-bin \
+    --task simple_classification \
+    --arch pytorch_tutorial_rnn \
+    --optimizer adam --lr 0.001 --lr-shrink 0.5 \
+    --max-tokens 1000
+  (...)
+  | epoch 027 | loss 1.200 | ppl 2.30 | wps 15728 | ups 119.4 | wpb 116 | bsz 116 | num_updates 3726 | lr 1.5625e-05 | gnorm 1.290 | clip 0% | oom 0 | wall 32 | train_wall 21
+  | epoch 027 | valid on 'valid' subset | valid_loss 1.41304 | valid_ppl 2.66 | num_updates 3726 | best 1.41208
+  | done training in 31.6 seconds
+
+The model files should appear in the :file:`checkpoints/` directory.
+
+
+5. Writing an evaluation script
+-------------------------------
+
+Finally we can write a short script to evaluate our model on new inputs. Create
+a new file named :file:`eval_classifier.py` with the following contents::
+
+  from fairseq import checkpoint_utils, data, options, tasks
+
+  # Parse command-line arguments for generation
+  parser = options.get_generation_parser(default_task='simple_classification')
+  args = options.parse_args_and_arch(parser)
+
+  # Setup task
+  task = tasks.setup_task(args)
+
+  # Load model
+  print('| loading model from {}'.format(args.path))
+  models, _model_args = checkpoint_utils.load_model_ensemble([args.path], task=task)
+  model = models[0]
+
+  while True:
+      sentence = input('\nInput: ')
+
+      # Tokenize into characters
+      chars = ' '.join(list(sentence.strip()))
+      tokens = task.source_dictionary.encode_line(
+          chars, add_if_not_exist=False,
+      )
+
+      # Build mini-batch to feed to the model
+      batch = data.language_pair_dataset.collate(
+          samples=[{'id': -1, 'source': tokens}],  # bsz = 1
+          pad_idx=task.source_dictionary.pad(),
+          eos_idx=task.source_dictionary.eos(),
+          left_pad_source=False,
+          input_feeding=False,
+      )
+
+      # Feed batch to the model and get predictions
+      preds = model(**batch['net_input'])
+
+      # Print top 3 predictions and their log-probabilities
+      top_scores, top_labels = preds[0].topk(k=3)
+      for score, label_idx in zip(top_scores, top_labels):
+          label_name = task.target_dictionary.string([label_idx])
+          print('({:.2f})\t{}'.format(score, label_name))
+
+Now we can evaluate our model interactively. Note that we have included the
+original data path (:file:`names-bin/`) so that the dictionaries can be loaded:
+
+.. code-block:: console
+
+  > python eval_classifier.py names-bin --path checkpoints/checkpoint_best.pt
+  | [input] dictionary: 64 types
+  | [label] dictionary: 24 types
+  | loading model from checkpoints/checkpoint_best.pt
+
+  Input: Satoshi
+  (-0.61) Japanese
+  (-1.20) Arabic
+  (-2.86) Italian
+
+  Input: Sinbad
+  (-0.30) Arabic
+  (-1.76) English
+  (-4.08) Russian

+ 518 - 0
docs/tutorial_simple_lstm.rst

@@ -0,0 +1,518 @@
+Tutorial: Simple LSTM
+=====================
+
+In this tutorial we will extend fairseq by adding a new
+:class:`~fairseq.models.FairseqEncoderDecoderModel` that encodes a source
+sentence with an LSTM and then passes the final hidden state to a second LSTM
+that decodes the target sentence (without attention).
+
+This tutorial covers:
+
+1. **Writing an Encoder and Decoder** to encode/decode the source/target
+   sentence, respectively.
+2. **Registering a new Model** so that it can be used with the existing
+   :ref:`Command-line tools`.
+3. **Training the Model** using the existing command-line tools.
+4. **Making generation faster** by modifying the Decoder to use
+   :ref:`Incremental decoding`.
+
+
+1. Building an Encoder and Decoder
+----------------------------------
+
+In this section we'll define a simple LSTM Encoder and Decoder. All Encoders
+should implement the :class:`~fairseq.models.FairseqEncoder` interface and
+Decoders should implement the :class:`~fairseq.models.FairseqDecoder` interface.
+These interfaces themselves extend :class:`torch.nn.Module`, so FairseqEncoders
+and FairseqDecoders can be written and used in the same ways as ordinary PyTorch
+Modules.
+
+
+Encoder
+~~~~~~~
+
+Our Encoder will embed the tokens in the source sentence, feed them to a
+:class:`torch.nn.LSTM` and return the final hidden state. To create our encoder
+save the following in a new file named :file:`fairseq/models/simple_lstm.py`::
+
+  import torch.nn as nn
+  from fairseq import utils
+  from fairseq.models import FairseqEncoder
+
+  class SimpleLSTMEncoder(FairseqEncoder):
+
+      def __init__(
+          self, args, dictionary, embed_dim=128, hidden_dim=128, dropout=0.1,
+      ):
+          super().__init__(dictionary)
+          self.args = args
+
+          # Our encoder will embed the inputs before feeding them to the LSTM.
+          self.embed_tokens = nn.Embedding(
+              num_embeddings=len(dictionary),
+              embedding_dim=embed_dim,
+              padding_idx=dictionary.pad(),
+          )
+          self.dropout = nn.Dropout(p=dropout)
+
+          # We'll use a single-layer, unidirectional LSTM for simplicity.
+          self.lstm = nn.LSTM(
+              input_size=embed_dim,
+              hidden_size=hidden_dim,
+              num_layers=1,
+              bidirectional=False,
+              batch_first=True,
+          )
+
+      def forward(self, src_tokens, src_lengths):
+          # The inputs to the ``forward()`` function are determined by the
+          # Task, and in particular the ``'net_input'`` key in each
+          # mini-batch. We discuss Tasks in the next tutorial, but for now just
+          # know that *src_tokens* has shape `(batch, src_len)` and *src_lengths*
+          # has shape `(batch)`.
+
+          # Note that the source is typically padded on the left. This can be
+          # configured by adding the `--left-pad-source "False"` command-line
+          # argument, but here we'll make the Encoder handle either kind of
+          # padding by converting everything to be right-padded.
+          if self.args.left_pad_source:
+              # Convert left-padding to right-padding.
+              src_tokens = utils.convert_padding_direction(
+                  src_tokens,
+                  padding_idx=self.dictionary.pad(),
+                  left_to_right=True
+              )
+
+          # Embed the source.
+          x = self.embed_tokens(src_tokens)
+
+          # Apply dropout.
+          x = self.dropout(x)
+
+          # Pack the sequence into a PackedSequence object to feed to the LSTM.
+          x = nn.utils.rnn.pack_padded_sequence(x, src_lengths, batch_first=True)
+
+          # Get the output from the LSTM.
+          _outputs, (final_hidden, _final_cell) = self.lstm(x)
+
+          # Return the Encoder's output. This can be any object and will be
+          # passed directly to the Decoder.
+          return {
+              # this will have shape `(bsz, hidden_dim)`
+              'final_hidden': final_hidden.squeeze(0),
+          }
+
+      # Encoders are required to implement this method so that we can rearrange
+      # the order of the batch elements during inference (e.g., beam search).
+      def reorder_encoder_out(self, encoder_out, new_order):
+          """
+          Reorder encoder output according to `new_order`.
+
+          Args:
+              encoder_out: output from the ``forward()`` method
+              new_order (LongTensor): desired order
+
+          Returns:
+              `encoder_out` rearranged according to `new_order`
+          """
+          final_hidden = encoder_out['final_hidden']
+          return {
+              'final_hidden': final_hidden.index_select(0, new_order),
+          }
+
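+As a quick sanity check (this snippet is illustrative and not part of the
+tutorial files; the tiny ``Dictionary`` built here is just a stand-in for a
+real preprocessed vocabulary), you can exercise the Encoder on dummy inputs
+and confirm the shape of its output::
+
+  import torch
+  from argparse import Namespace
+  from fairseq.data import Dictionary
+
+  # Build a toy dictionary so the Encoder has something to embed.
+  toy_dict = Dictionary()
+  for word in ['hello', 'world']:
+      toy_dict.add_symbol(word)
+
+  encoder = SimpleLSTMEncoder(
+      args=Namespace(left_pad_source=False),
+      dictionary=toy_dict,
+      embed_dim=8,
+      hidden_dim=8,
+  )
+  src_tokens = torch.LongTensor([[toy_dict.index('hello'), toy_dict.index('world')]])
+  src_lengths = torch.LongTensor([2])
+  encoder_out = encoder(src_tokens, src_lengths)
+  print(encoder_out['final_hidden'].shape)  # torch.Size([1, 8])
+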
+
+Decoder
+~~~~~~~
+
+Our Decoder will predict the next word, conditioned on the Encoder's final
+hidden state and an embedded representation of the previous target word -- which
+is sometimes called *teacher forcing*. More specifically, we'll use a
+:class:`torch.nn.LSTM` to produce a sequence of hidden states that we'll project
+to the size of the output vocabulary to predict each target word.
+
+::
+
+  import torch
+  from fairseq.models import FairseqDecoder
+
+  class SimpleLSTMDecoder(FairseqDecoder):
+
+      def __init__(
+          self, dictionary, encoder_hidden_dim=128, embed_dim=128, hidden_dim=128,
+          dropout=0.1,
+      ):
+          super().__init__(dictionary)
+
+          # Our decoder will embed the inputs before feeding them to the LSTM.
+          self.embed_tokens = nn.Embedding(
+              num_embeddings=len(dictionary),
+              embedding_dim=embed_dim,
+              padding_idx=dictionary.pad(),
+          )
+          self.dropout = nn.Dropout(p=dropout)
+
+          # We'll use a single-layer, unidirectional LSTM for simplicity.
+          self.lstm = nn.LSTM(
+              # For the first layer we'll concatenate the Encoder's final hidden
+              # state with the embedded target tokens.
+              input_size=encoder_hidden_dim + embed_dim,
+              hidden_size=hidden_dim,
+              num_layers=1,
+              bidirectional=False,
+          )
+
+          # Define the output projection.
+          self.output_projection = nn.Linear(hidden_dim, len(dictionary))
+
+      # During training Decoders are expected to take the entire target sequence
+      # (shifted right by one position) and produce logits over the vocabulary.
+      # The *prev_output_tokens* tensor begins with the end-of-sentence symbol,
+      # ``dictionary.eos()``, followed by the target sequence.
+      def forward(self, prev_output_tokens, encoder_out):
+          """
+          Args:
+              prev_output_tokens (LongTensor): previous decoder outputs of shape
+                  `(batch, tgt_len)`, for teacher forcing
+              encoder_out (dict): output from the encoder's ``forward()``,
+                  here a dict holding the Encoder's final hidden state
+
+          Returns:
+              tuple:
+                  - the last decoder layer's output of shape
+                    `(batch, tgt_len, vocab)`
+                  - the last decoder layer's attention weights of shape
+                    `(batch, tgt_len, src_len)`
+          """
+          bsz, tgt_len = prev_output_tokens.size()
+
+          # Extract the final hidden state from the Encoder.
+          final_encoder_hidden = encoder_out['final_hidden']
+
+          # Embed the target sequence, which has been shifted right by one
+          # position and now starts with the end-of-sentence symbol.
+          x = self.embed_tokens(prev_output_tokens)
+
+          # Apply dropout.
+          x = self.dropout(x)
+
+          # Concatenate the Encoder's final hidden state to *every* embedded
+          # target token.
+          x = torch.cat(
+              [x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)],
+              dim=2,
+          )
+
+          # Using PackedSequence objects in the Decoder is harder than in the
+          # Encoder, since the targets are not sorted in descending length order,
+          # which is a requirement of ``pack_padded_sequence()``. Instead we'll
+          # feed nn.LSTM directly.
+          initial_state = (
+              final_encoder_hidden.unsqueeze(0),  # hidden
+              torch.zeros_like(final_encoder_hidden).unsqueeze(0),  # cell
+          )
+          output, _ = self.lstm(
+              x.transpose(0, 1),  # convert to shape `(tgt_len, bsz, dim)`
+              initial_state,
+          )
+          x = output.transpose(0, 1)  # convert to shape `(bsz, tgt_len, hidden)`
+
+          # Project the outputs to the size of the vocabulary.
+          x = self.output_projection(x)
+
+          # Return the logits and ``None`` for the attention weights
+          return x, None
+
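+As with the Encoder, here is a small illustrative check (again, not part of
+the tutorial files) that runs the Decoder on a fake Encoder output, reusing
+the ``toy_dict`` from the Encoder snippet above::
+
+  decoder = SimpleLSTMDecoder(
+      dictionary=toy_dict,
+      encoder_hidden_dim=8,
+      embed_dim=8,
+      hidden_dim=8,
+  )
+  fake_encoder_out = {'final_hidden': torch.randn(1, 8)}
+  prev_output_tokens = torch.LongTensor([[toy_dict.eos(), toy_dict.index('hello')]])
+  logits, _ = decoder(prev_output_tokens, fake_encoder_out)
+  print(logits.shape)  # (1, 2, len(toy_dict))
+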
+
+2. Registering the Model
+------------------------
+
+Now that we've defined our Encoder and Decoder we must *register* our model with
+fairseq using the :func:`~fairseq.models.register_model` function decorator.
+Once the model is registered we'll be able to use it with the existing
+:ref:`Command-line Tools`.
+
+All registered models must implement the
+:class:`~fairseq.models.BaseFairseqModel` interface. For sequence-to-sequence
+models (i.e., any model with a single Encoder and Decoder), we can instead
+implement the :class:`~fairseq.models.FairseqEncoderDecoderModel` interface.
+
+Create a small wrapper class in the same file and register it in fairseq with
+the name ``'simple_lstm'``::
+
+  from fairseq.models import FairseqEncoderDecoderModel, register_model
+
+  # Note: the register_model "decorator" should immediately precede the
+  # definition of the Model class.
+
+  @register_model('simple_lstm')
+  class SimpleLSTMModel(FairseqEncoderDecoderModel):
+
+      @staticmethod
+      def add_args(parser):
+          # Models can override this method to add new command-line arguments.
+          # Here we'll add some new command-line arguments to configure dropout
+          # and the dimensionality of the embeddings and hidden states.
+          parser.add_argument(
+              '--encoder-embed-dim', type=int, metavar='N',
+              help='dimensionality of the encoder embeddings',
+          )
+          parser.add_argument(
+              '--encoder-hidden-dim', type=int, metavar='N',
+              help='dimensionality of the encoder hidden state',
+          )
+          parser.add_argument(
+              '--encoder-dropout', type=float, default=0.1,
+              help='encoder dropout probability',
+          )
+          parser.add_argument(
+              '--decoder-embed-dim', type=int, metavar='N',
+              help='dimensionality of the decoder embeddings',
+          )
+          parser.add_argument(
+              '--decoder-hidden-dim', type=int, metavar='N',
+              help='dimensionality of the decoder hidden state',
+          )
+          parser.add_argument(
+              '--decoder-dropout', type=float, default=0.1,
+              help='decoder dropout probability',
+          )
+
+      @classmethod
+      def build_model(cls, args, task):
+          # Fairseq initializes models by calling the ``build_model()``
+          # function. This provides more flexibility, since the returned model
+          # instance can be of a different type than the class on which
+          # ``build_model()`` was called.
+          # In this case we'll just return a SimpleLSTMModel instance.
+
+          # Initialize our Encoder and Decoder.
+          encoder = SimpleLSTMEncoder(
+              args=args,
+              dictionary=task.source_dictionary,
+              embed_dim=args.encoder_embed_dim,
+              hidden_dim=args.encoder_hidden_dim,
+              dropout=args.encoder_dropout,
+          )
+          decoder = SimpleLSTMDecoder(
+              dictionary=task.target_dictionary,
+              encoder_hidden_dim=args.encoder_hidden_dim,
+              embed_dim=args.decoder_embed_dim,
+              hidden_dim=args.decoder_hidden_dim,
+              dropout=args.decoder_dropout,
+          )
+          model = SimpleLSTMModel(encoder, decoder)
+
+          # Print the model architecture.
+          print(model)
+
+          return model
+
+      # We could override the ``forward()`` if we wanted more control over how
+      # the encoder and decoder interact, but it's not necessary for this
+      # tutorial since we can inherit the default implementation provided by
+      # the FairseqEncoderDecoderModel base class, which looks like:
+      #
+      # def forward(self, src_tokens, src_lengths, prev_output_tokens):
+      #     encoder_out = self.encoder(src_tokens, src_lengths)
+      #     decoder_out = self.decoder(prev_output_tokens, encoder_out)
+      #     return decoder_out
+
+Finally let's define a *named architecture* with the configuration for our
+model. This is done with the :func:`~fairseq.models.register_model_architecture`
+function decorator. Thereafter this named architecture can be used with the
+``--arch`` command-line argument, e.g., ``--arch tutorial_simple_lstm``::
+
+  from fairseq.models import register_model_architecture
+
+  # The first argument to ``register_model_architecture()`` should be the name
+  # of the model we registered above (i.e., 'simple_lstm'). The function we
+  # register here should take a single argument *args* and modify it in-place
+  # to match the desired architecture.
+
+  @register_model_architecture('simple_lstm', 'tutorial_simple_lstm')
+  def tutorial_simple_lstm(args):
+      # We use ``getattr()`` to prioritize arguments that are explicitly given
+      # on the command-line, so that the defaults defined below are only used
+      # when no other value has been specified.
+      args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
+      args.encoder_hidden_dim = getattr(args, 'encoder_hidden_dim', 256)
+      args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256)
+      args.decoder_hidden_dim = getattr(args, 'decoder_hidden_dim', 256)
+
+
+3. Training the Model
+---------------------
+
+Now we're ready to train the model. We can use the existing :ref:`fairseq-train`
+command-line tool for this, making sure to specify our new Model architecture
+(``--arch tutorial_simple_lstm``).
+
+.. note::
+
+  Make sure you've already preprocessed the data from the IWSLT example in the
+  :file:`examples/translation/` directory.
+
+.. code-block:: console
+
+  > fairseq-train data-bin/iwslt14.tokenized.de-en \
+    --arch tutorial_simple_lstm \
+    --encoder-dropout 0.2 --decoder-dropout 0.2 \
+    --optimizer adam --lr 0.005 --lr-shrink 0.5 \
+    --max-tokens 12000
+  (...)
+  | epoch 052 | loss 4.027 | ppl 16.30 | wps 420805 | ups 39.7 | wpb 9841 | bsz 400 | num_updates 20852 | lr 1.95313e-05 | gnorm 0.218 | clip 0% | oom 0 | wall 529 | train_wall 396
+  | epoch 052 | valid on 'valid' subset | valid_loss 4.74989 | valid_ppl 26.91 | num_updates 20852 | best 4.74954
+
+The model files should appear in the :file:`checkpoints/` directory. While this
+model architecture is not very good, we can use the :ref:`fairseq-generate` script to
+generate translations and compute our BLEU score over the test set:
+
+.. code-block:: console
+
+  > fairseq-generate data-bin/iwslt14.tokenized.de-en \
+    --path checkpoints/checkpoint_best.pt \
+    --beam 5 \
+    --remove-bpe
+  (...)
+  | Translated 6750 sentences (153132 tokens) in 17.3s (389.12 sentences/s, 8827.68 tokens/s)
+  | Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)
+
+
+4. Making generation faster
+---------------------------
+
+While autoregressive generation from sequence-to-sequence models is inherently
+slow, our implementation above is especially slow because it recomputes the
+entire sequence of Decoder hidden states for every output token (i.e., it is
+``O(n^2)``). We can make this significantly faster by instead caching the
+previous hidden states.
+
+In fairseq this is called :ref:`Incremental decoding`. Incremental decoding is a
+special mode at inference time where the Model only receives a single timestep
+of input corresponding to the immediately previous output token (for teacher
+forcing) and must produce the next output incrementally. Thus the model must
+cache any long-term state that is needed about the sequence, e.g., hidden
+states, convolutional states, etc.
+
+To implement incremental decoding we will modify our model to implement the
+:class:`~fairseq.models.FairseqIncrementalDecoder` interface. Compared to the
+standard :class:`~fairseq.models.FairseqDecoder` interface, the incremental
+decoder interface allows ``forward()`` methods to take an extra keyword argument
+(*incremental_state*) that can be used to cache state across time-steps.
+
+Let's replace our ``SimpleLSTMDecoder`` with an incremental one::
+
+  import torch
+  from fairseq.models import FairseqIncrementalDecoder
+
+  class SimpleLSTMDecoder(FairseqIncrementalDecoder):
+
+      def __init__(
+          self, dictionary, encoder_hidden_dim=128, embed_dim=128, hidden_dim=128,
+          dropout=0.1,
+      ):
+          # This remains the same as before.
+          super().__init__(dictionary)
+          self.embed_tokens = nn.Embedding(
+              num_embeddings=len(dictionary),
+              embedding_dim=embed_dim,
+              padding_idx=dictionary.pad(),
+          )
+          self.dropout = nn.Dropout(p=dropout)
+          self.lstm = nn.LSTM(
+              input_size=encoder_hidden_dim + embed_dim,
+              hidden_size=hidden_dim,
+              num_layers=1,
+              bidirectional=False,
+          )
+          self.output_projection = nn.Linear(hidden_dim, len(dictionary))
+
+      # We now take an additional kwarg (*incremental_state*) for caching the
+      # previous hidden and cell states.
+      def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
+          if incremental_state is not None:
+              # If the *incremental_state* argument is not ``None`` then we are
+              # in incremental inference mode. While *prev_output_tokens* will
+              # still contain the entire decoded prefix, we will only use the
+              # last step and assume that the rest of the state is cached.
+              prev_output_tokens = prev_output_tokens[:, -1:]
+
+          # This remains the same as before.
+          bsz, tgt_len = prev_output_tokens.size()
+          final_encoder_hidden = encoder_out['final_hidden']
+          x = self.embed_tokens(prev_output_tokens)
+          x = self.dropout(x)
+          x = torch.cat(
+              [x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)],
+              dim=2,
+          )
+
+          # We will now check the cache and load the cached previous hidden and
+          # cell states, if they exist, otherwise we will initialize them to
+          # zeros (as before). We will use the ``utils.get_incremental_state()``
+          # and ``utils.set_incremental_state()`` helpers.
+          initial_state = utils.get_incremental_state(
+              self, incremental_state, 'prev_state',
+          )
+          if initial_state is None:
+              # first time initialization, same as the original version
+              initial_state = (
+                  final_encoder_hidden.unsqueeze(0),  # hidden
+                  torch.zeros_like(final_encoder_hidden).unsqueeze(0),  # cell
+              )
+
+          # Run one step of our LSTM.
+          output, latest_state = self.lstm(x.transpose(0, 1), initial_state)
+
+          # Update the cache with the latest hidden and cell states.
+          utils.set_incremental_state(
+              self, incremental_state, 'prev_state', latest_state,
+          )
+
+          # This remains the same as before
+          x = output.transpose(0, 1)
+          x = self.output_projection(x)
+          return x, None
+
+      # The ``FairseqIncrementalDecoder`` interface also requires implementing a
+      # ``reorder_incremental_state()`` method, which is used during beam search
+      # to select and reorder the incremental state.
+      def reorder_incremental_state(self, incremental_state, new_order):
+          # Load the cached state.
+          prev_state = utils.get_incremental_state(
+              self, incremental_state, 'prev_state',
+          )
+
+          # Reorder batches according to *new_order*.
+          reordered_state = (
+              prev_state[0].index_select(1, new_order),  # hidden
+              prev_state[1].index_select(1, new_order),  # cell
+          )
+
+          # Update the cached state.
+          utils.set_incremental_state(
+              self, incremental_state, 'prev_state', reordered_state,
+          )
+
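+To see the cache in action, here is a rough greedy decoding sketch (purely
+illustrative and not part of the tutorial files; during real inference the
+:ref:`fairseq-generate` tool drives this loop for you). The
+``incremental_state`` dictionary is created once and passed to every call, so
+each call only processes the newest token, reusing the toy objects from the
+earlier snippets::
+
+  decoder = SimpleLSTMDecoder(
+      dictionary=toy_dict, encoder_hidden_dim=8, embed_dim=8, hidden_dim=8,
+  )
+  fake_encoder_out = {'final_hidden': torch.randn(1, 8)}
+  incremental_state = {}
+  tokens = torch.LongTensor([[toy_dict.eos()]])
+  for _ in range(5):
+      logits, _ = decoder(tokens, fake_encoder_out, incremental_state=incremental_state)
+      next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
+      tokens = torch.cat([tokens, next_token], dim=1)
+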
+Finally, we can rerun generation and observe the speedup:
+
+.. code-block:: console
+
+  # Before
+
+  > fairseq-generate data-bin/iwslt14.tokenized.de-en \
+    --path checkpoints/checkpoint_best.pt \
+    --beam 5 \
+    --remove-bpe
+  (...)
+  | Translated 6750 sentences (153132 tokens) in 17.3s (389.12 sentences/s, 8827.68 tokens/s)
+  | Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)
+
+  # After
+
+  > fairseq-generate data-bin/iwslt14.tokenized.de-en \
+    --path checkpoints/checkpoint_best.pt \
+    --beam 5 \
+    --remove-bpe
+  (...)
+  | Translated 6750 sentences (153132 tokens) in 5.5s (1225.54 sentences/s, 27802.94 tokens/s)
+  | Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)

+ 2 - 0
examples/.gitignore

@@ -0,0 +1,2 @@
+!*/*.sh
+!*/*.md

+ 139 - 0
examples/MMPT/.gitignore

@@ -0,0 +1,139 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+#   However, in case of collaboration, if having platform-specific dependencies or dependencies
+#   having no cross-platform support, pipenv may install dependencies that don't work, or not
+#   install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+runs
+data
+pretrained_models
+projects/mmfusion_*
+log_test
+third-party
+python_log
+slurm_snapshot_code
+lightning_logs
+demos

+ 41 - 0
examples/MMPT/CONFIG.md

@@ -0,0 +1,41 @@
+### Config Files Explained
+
+Take `projects/mfmmlm.yaml` as an example: it runs pretraining with a masked frame model (MFM) and a masked language model (MLM) on a single BERT:  
+
+```yaml
+project_dir: mfmmlm # specify the project dir for this baseline.
+run_task:
+  - how2.yaml # run pretraining on how2 when launching `projects/taskmfmmlm.yaml`
+  - [vtt.yaml, vttcap.yaml, vttqa.yaml, youcook.yaml, youcookcap.yaml, crosstask.yaml, coin.yaml] # run fine-tuning tasks.
+base_dir: task # a global template folder to specify each training task. 
+task_group:
+  pretrain: # section for pretraining. Most baselines differ in this section.
+    task_list:
+      - how2.yaml # reconfig `projects/task/how2.yaml`
+    dataset:
+      aligner: MFMMLMAligner # overwrite the aligner for MFMMLM training task.
+    model:
+      model_cls: MMFusionMFMMLM # overwrite the model, which constructs negative examples for MFM on-the-fly.
+    loss:
+      loss_cls: MFMMLM # overwrite the loss as MFMMLM, which combines MFM and MLM together.
+    fairseq: # all fairseq args can be specified under this name.
+      dataset:
+        batch_size: 128
+  finetune: # section for fine-tuning tasks; mostly nothing needs to change here, since we want to see how pretraining contributes to finetuning.
+    task_list: # specify the list of downstream tasks, e.g., copy `projects/task/vtt.yaml` to `projects/mfmmlm`.
+      - vtt.yaml
+      - vttqa.yaml
+      - youcook.yaml
+      - youcookcap.yaml
+      - crosstask.yaml
+      - coin.yaml
+  test: # section for testing.
+    task_list:
+      - test_vtt.yaml
+      - test_vttqa.yaml
+      - test_youcook.yaml
+      - test_youcookcap.yaml
+      - test_crosstask.yaml
+      - test_crosstask_zs.yaml
+      - test_coin.yaml
+```
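+
+Under the hood, these per-section overrides are applied by `locallaunch.py` on top of the template configs in `projects/task/` via `OmegaConf.merge`. A minimal sketch of the merge behavior (the base field values below are made up for illustration):
+
+```python
+from omegaconf import OmegaConf
+
+# template config, e.g., loaded from a file under projects/task/
+base = OmegaConf.create({"dataset": {"aligner": "SomeBaseAligner"}, "loss": {"loss_cls": "MLM"}})
+# the `pretrain` section overrides from projects/mfmmlm.yaml
+override = OmegaConf.create({"dataset": {"aligner": "MFMMLMAligner"}, "loss": {"loss_cls": "MFMMLM"}})
+
+merged = OmegaConf.merge(base, override)  # override wins field-by-field
+print(OmegaConf.to_yaml(merged))
+```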

+ 34 - 0
examples/MMPT/DATASET.md

@@ -0,0 +1,34 @@
+# Dataset
+
+We understand that video data are challenging to download and process. For videos, we provide our preprocessing scripts under `scripts/video_feature_extractor` (deeply adapted from `https://github.com/antoine77340/video_feature_extractor`); for text, we provide pre-tokenization scripts under `scripts/text_token_extractor`.
+
+### S3D Feature Extraction
+We use pre-trained [S3D](https://github.com/antoine77340/S3D_HowTo100M) for video feature extraction. Please place the models as `pretrained_models/s3d_dict.npy` and `pretrained_models/s3d_howto100m.pth`.
+
+We implement a `PathBuilder` to automatically map video ids and source video paths to their feature locations (you may need `conda install -c anaconda pandas`). Decoding may need `pip install ffmpeg-python`.
+
+### Howto100M
+[Howto100M](https://www.di.ens.fr/willow/research/howto100m/) is a large-scale video pre-training dataset. You may download the videos yourself and run our preprocessing scripts. 
+
+Our preprocessing differs from existing papers in several key ways: (1) we use `raw_caption.json` instead of `caption.json` to keep pure self-supervision on text (`caption.json` has stop words manually removed); (2) we remove partially duplicated texts that were originally included for real-time readability (see `mmpt/processors/dedupprocessor.py`); (3) we then shard video/text features using `ShardedTensor` in `mmpt/utils/shardedtensor.py` for fast loading during training (faster than `h5py`).
+
+#### Steps
+##### video
+To extract video features, edit and run `bash scripts/video_feature_extractor/how2/s3d.sh`. (Consider running this on multiple machines; by default, we store features in fp16 to save space and speed up training.)
+
+Split the available video ids into `data/how2/how2_s3d_train.lst` and `data/how2/how2_s3d_val.lst`.
+
+Lastly, pack video features into `ShardedTensor` using `python scripts/video_feature_extractor/shard_feature.py`.
+
+##### text
+Clean captions using `python -m mmpt.processors.dedupprocessor`.
+
+Tokenize the deduplicated captions `data/how2/raw_caption_dedup.pkl` into sharded numpy arrays:  
+```
+python scripts/text_token_extractor/pretokenization.py scripts/text_token_extractor/configs/bert-base-uncased.yaml
+```
+
+### Youcook, MSRVTT etc.
+We use the versions of Youcook and MSRVTT that come with Howto100M and MILNCE. Please download the data to `data/youcook` and `data/msrvtt` accordingly; see `projects/task/youcook.yaml`, `projects/task/vtt.yaml`, etc. for details. 
+We extract features for Youcook and MSRVTT similarly to the first step of Howto100M, but read text from the metadata directly and perform on-the-fly tokenization.
+

+ 166 - 0
examples/MMPT/README.md

@@ -0,0 +1,166 @@
+# VideoCLIP and VLM
+
+You just found this toolkit for multimodal video understanding! It contains implementations of two recent multi-modal video understanding papers, [VideoCLIP](https://arxiv.org/pdf/2109.14084.pdf) (EMNLP, 2021) and [VLM](https://aclanthology.org/2021.findings-acl.370.pdf) (ACL Findings, 2021), along with high-performance toolkits that are typically lacking in existing codebases. The toolkit is designed around generic, performance-tuned components that can potentially be adapted to other frameworks (we initially use fairseq). 
+
+VideoCLIP is a contrastive learning model for zero-shot transfer to retrieval/classification/sequence labeling style tasks.
+
+<img src="videoclip.png" width="350" class="center">
+
+VLM is a masked-language-model-style pre-training approach that uses a single encoder with a masked modality model (MMM), targeting retrieval/generation/sequence labeling style tasks.
+
+<img src="vlm.png" width="350" class="center">
+
+### News
+[Oct. 2021] Initial release of implementation for the following papers:  
+[VideoCLIP: Contrastive Pre-training for Zero-shot Video-Text Understanding](https://arxiv.org/pdf/2109.14084.pdf) (Xu et. al., EMNLP 2021)  
+[VLM: Task-agnostic Video-Language Model Pre-training for Video Understanding](https://aclanthology.org/2021.findings-acl.370.pdf) (Xu et. al., ACL Findings 2021)  
+
+
+### Installation
+We aim to minimize this repo's dependencies on other packages.  
+We use fairseq as the main trainer (models/datasets have no dependency on fairseq; we will support other trainers in the future):  
+```
+git clone https://github.com/pytorch/fairseq
+cd fairseq
+pip install -e .  # also optionally follow fairseq README for apex installation for fp16 training.
+export MKL_THREADING_LAYER=GNU  # fairseq may need this for numpy.
+```
+
+Then install this toolkit:
+```
+cd examples/MMPT  # MMPT can be in any folder, not necessarily under fairseq/examples.
+pip install -e .
+```
+
+The code was developed under Python=3.8.8, PyTorch=1.8, CUDA=11.0 with fairseq=1.0.0a0+af0389f, and tested under Python=3.8.8, PyTorch=1.9, CUDA=11.0, fairseq=1.0.0a0+8e7bc73 at the time of code release.
+Most models require `transformers==3.4` for API compatibility (`pip install transformers==3.4`). 
+In addition, some downstream tasks may need `conda install pandas`.  
+
+
+### Usage
+#### Download Checkpoints
+We use pre-trained [S3D](https://github.com/antoine77340/S3D_HowTo100M) for video feature extraction. Please place the models as `pretrained_models/s3d_dict.npy` and `pretrained_models/s3d_howto100m.pth`.
+
+Download VideoCLIP checkpoint `https://dl.fbaipublicfiles.com/MMPT/retri/videoclip/checkpoint_best.pt` to `runs/retri/videoclip` or VLM checkpoint `https://dl.fbaipublicfiles.com/MMPT/mtm/vlm/checkpoint_best.pt` to `runs/mtm/vlm`.
+
+#### Demo of Inference
+First run `python locallaunch.py projects/retri/videoclip.yaml --dryrun` to generate all `.yaml`s for VideoCLIP.
+
+```python
+import torch
+
+from mmpt.models import MMPTModel
+
+
+model, tokenizer, aligner = MMPTModel.from_pretrained(
+    "projects/retri/videoclip/how2.yaml")
+
+model.eval()
+
+
+# B, T, FPS, H, W, C (VideoCLIP is trained on 30 fps of s3d)
+video_frames = torch.randn(1, 2, 30, 224, 224, 3)
+caps, cmasks = aligner._build_text_seq(
+    tokenizer("some text", add_special_tokens=False)["input_ids"]
+)
+
+caps, cmasks = caps[None, :], cmasks[None, :]  # bsz=1
+
+with torch.no_grad():
+    output = model(video_frames, caps, cmasks, return_score=True)
+print(output["score"])  # dot-product
+```
+
+#### Data Preparation
+See [dataset](DATASET.md) for each dataset.
+
+#### Global Config for Training Pipeline
+We organize a global config file for each training/testing pipeline under `projects` (see the detailed [explanation](CONFIG.md)). For example, VideoCLIP is in `projects/retri/videoclip.yaml` and VLM is in `projects/mtm/vlm.yaml`.
+
+We wrap all commands in `locallaunch.py` and `mmpt_cli/localjob.py`. You can inspect the concrete commands with `--dryrun` and then drop the flag for an actual run.  
+
+First, running `python locallaunch.py projects/retri/videoclip.yaml --dryrun` will generate the configs for pre-training, zero-shot evaluation, fine-tuning, and testing of VideoCLIP under `projects/retri/videoclip`.  
+
+Each process (training or evaluation) is then configured by a concrete config file (for reproducibility, we save all complex arguments, including fairseq args, into the concrete config file). For example, to run zero-shot evaluation on Youcook:
+```
+python locallaunch.py projects/retri/videoclip/test_youcook_zs.yaml --jobtype local_predict  # zero-shot evaluation.
+python locallaunch.py projects/retri/videoclip/youcook_videoclip.yaml --jobtype local_single --dryrun  # fine-tuning: use --dryrun to check cmds and drop it to make an actual run; local_small will run on two gpus (as in paper).
+python locallaunch.py projects/retri/videoclip/test_youcook_videoclip.yaml --jobtype local_predict  # testing on fine-tuned model.
+```
+
+Pretraining can be run as:  
+```
+python locallaunch.py projects/retri/videoclip/how2.yaml --jobtype local_single --dryrun # check the cmds, then drop --dryrun; the paper's runs use local_big with 8 GPUs.
+```
+You may need to change `--jobtype`; check/extend `LocalJob` in `mmpt_cli/localjob.py` for multi-GPU/multi-node pre-training.
+
+Detailed instructions for pretraining and fine-tuning can be found in the [pretraining instructions](pretraining.md) and [finetuning instructions](endtask.md).
+
+
+### Development
+Several components of this toolkit can be re-used for future research (and also our ongoing research).
+
+#### Framework Wrapper
+We currently only support fairseq, but most components can easily be adapted to other frameworks such as huggingface. This repo is a fairseq `--user-dir` with thin fairseq wrappers. For example, `mmpt/tasks` includes a `FairseqMMTTask`, which wraps `mmpt/datasets` with `FairseqDataset`, `mmpt/models` with `FairseqModel`, and `mmpt/losses` with `FairseqCriterion`.  
+
+#### Processors
+**Multi**modal research introduces complexity in modality alignment, from the different input sources all the way to the losses. Inspired by [MMF](https://github.com/facebookresearch/mmf), this toolkit leverages `mmpt/processors` to handle the various needs of data preprocessing and loading, **alleviating** the need for multiple `torch.utils.data.Dataset` classes (which can be tricky for ablation studies).  
+Processors can also be decoupled from `torch.utils.data.Dataset` for offline preprocessing instead of on-the-fly preprocessing.
+
+We decompose `mmpt.MMDataset` into four types of processors: `MetaProcessor`, `VideoProcessor`, `TextProcessor` and `Aligner`. They can be configured in the `dataset` field of a config file (e.g., see `projects/task/how2.yaml`) and compose as in the sketch after this list.  
+`MetaProcessor` is used to load the meta data of a dataset, e.g., all video_ids of the how2 dataset.  
+`VideoProcessor` is used to load the video features of a dataset, e.g., S3D features for each second of a video.  
+`TextProcessor` is used to load the text (features), e.g., BERT pre-tokenized text clips for the how2 dataset (with `start`/`end` timestamps and `cap` for `token_ids`).  
+`Aligner` is the core class that differs across baselines and prepares the training example, e.g., by sampling a clip and masking tokens for MLM.  
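+
+Roughly, `mmpt.MMDataset` composes these processors as follows (excerpted and simplified from `mmpt/datasets/mmdataset.py`; this is a schematic, not a standalone script):
+
+```python
+# inside MMDataset.__getitem__(self, idx), simplified:
+video_id, text_id = self.meta_processor[idx]           # which example to load
+video_feature = self.video_processor(video_id)         # e.g., per-second S3D features
+text_feature = self.text_processor(text_id)            # e.g., pre-tokenized captions
+output = self.align_processor(video_id, video_feature, text_feature)  # one training example
+```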
+
+#### Performance-tuned Components
+To speed up pre-training, this toolkit uses sharded features stored as memory-mapped numpy arrays, backed by `ShardedTensor` in `mmpt/utils/shardedtensor.py` (adopted from the MARGE paper). This reduces the IO load for multi-GPU training without loading all features of a video into memory each time, and `ShardedTensor` ensures features are stored in contiguous disk space for near-random access. This is used for both How2 video features and texts in `mmpt/processors/how2processor.py`.
+
+
+### Citation
+If this codebase is useful for your work, please cite the following papers:
+
+```BibTeX
+@inproceedings{xu-etal-2021-videoclip,
+    title = "{VideoCLIP}: Contrastive Pre-training for\\Zero-shot Video-Text Understanding",
+    author = "Xu, Hu  and
+      Ghosh, Gargi  and
+      Huang, Po-Yao  and
+      Okhonko, Dmytro  and
+      Aghajanyan, Armen  and
+      Metze, Florian  and
+      Zettlemoyer, Luke  and
+      Feichtenhofer, Christoph",
+    booktitle = "Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing (EMNLP)",
+    month = nov,
+    year = "2021",
+    address = "Online",
+    publisher = "Association for Computational Linguistics",
+}
+
+@inproceedings{xu-etal-2021-vlm,
+    title = "{VLM}: Task-agnostic Video-Language Model Pre-training for Video Understanding",
+    author = "Xu, Hu  and
+      Ghosh, Gargi  and
+      Huang, Po-Yao  and
+      Arora, Prahal  and
+      Aminzadeh, Masoumeh  and
+      Feichtenhofer, Christoph  and
+      Metze, Florian  and
+      Zettlemoyer, Luke",
+    booktitle = "Findings of the Association for Computational Linguistics: ACL-IJCNLP 2021",
+    month = aug,
+    year = "2021",
+    address = "Online",
+    publisher = "Association for Computational Linguistics",
+    url = "https://aclanthology.org/2021.findings-acl.370",
+    doi = "10.18653/v1/2021.findings-acl.370",
+    pages = "4227--4239",
+}
+```
+
+### Bug Reports
+This repo is in its initial stage; bug reports are welcome at huxu@fb.com
+
+### Copyright
+The majority of Multimodal Pre-training (MMPT) is licensed under CC-BY-NC, however portions of the project are available under separate license terms: Evaluation Codes/Models: Howto100M and HuggingFace Transformers are licensed under the Apache2.0 license; COIN and NLG-eval are licensed under the MIT license; CrossTask is licensed under the BSD-3; DiDeMo is licensed under the BSD-2 license.

+ 41 - 0
examples/MMPT/endtask.md

@@ -0,0 +1,41 @@
+# Zero-shot Transfer and Finetuning
+
+(If you are new to the ideas of `mmpt.processors`, see [README](README.md) first.)
+All finetuning datasets (specifically `processors`) are defined in `mmpt.processors.dsprocessor`.
+Given the complexity of the different types of finetuning tasks, each task may have its own meta/video/text/aligner processors and its own `mmpt/evaluators/{Predictor,Metric}`.
+
+### Tasks
+
+Currently, we support 5 end datasets: `MSRVTT`, `Youcook`, `COIN`, `Crosstask` and `DiDeMo` with the following tasks:  
+text-video retrieval: `MSRVTT`, `Youcook`, `DiDeMo`;   
+video captioning: `Youcook`;  
+Video Question and Answering: `MSRVTT-QA`.  
+
+To add your own dataset, you can specify the corresponding processors and config them in the `dataset` field of a config file, such as `projects/task/vtt.yaml`.
+
+### Zero-shot Transfer (no Training)
+Zero-shot transfer runs the pre-trained model (e.g., VideoCLIP) directly on testing data. Configs matching the pattern `projects/task/*_zs_*.yaml` are dedicated to zero-shot transfer.
+
+### Fine-tuning
+
+Training on a downstream task is similar to pretraining, except that you may need to specify `restore_file` in `fairseq.checkpoint` and reset the optimizer; see `projects/task/ft.yaml`, which is included by `projects/task/vtt.yaml`.
+
+We typically do finetuning on 2 gpus (`local_small`).
+
+### Testing
+For each finetuning dataset, you may need to specify a testing config, similar to `projects/task/test_vtt.yaml`.  
+
+We define `mmpt.evaluators.Predictor` for different types of prediction. For example, `MSRVTT` and `Youcook` are video-retrieval tasks and are expected to use `RetrievalPredictor`. You may need to define your own type of predictor and specify it in the `predictor` field of a testing config.
+
+Each task may also have its own evaluation metric. This can be created as a subclass of `mmpt.evaluators.Metric` and specified in the `metric` field of a testing config, as in the sketch below.
+
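+A minimal sketch of a custom metric (the class and its logic are illustrative, not shipped with the toolkit; only the `Metric` interface it follows comes from `mmpt/evaluators/metric.py`):
+
+```python
+import numpy as np
+
+from mmpt.evaluators.metric import Metric
+
+
+class MyAccuracyMetric(Metric):
+    """toy accuracy metric following the Metric interface."""
+
+    def __init__(self, config, metric_names=["acc"]):
+        super().__init__(config, metric_names)
+
+    def compute_metrics(self, outputs, targets, **kwargs):
+        preds = np.argmax(outputs, axis=1)  # outputs: (N, C) scores; targets: (N,) labels.
+        return {"acc": float((preds == targets).mean())}
+
+    def print_computed_metrics(self, metrics):
+        print("acc: {:.4f}".format(metrics["acc"]))
+```
+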
+Launching a test is as simple as launching training: just specify the path of a testing config:
+```python locallaunch.py projects/mfmmlm/test_vtt.yaml```
+Testing will be launched locally by default since prediction is computationally less expensive.
+
+### Third-party Libraries
+The following finetuning tasks require third-party libraries.
+
+Youcook captioning: `https://github.com/Maluuba/nlg-eval`  
+
+CrossTask: `https://github.com/DmZhukov/CrossTask`'s `dp` under `third-party/CrossTask` (`python setup.py build_ext --inplace`)

+ 148 - 0
examples/MMPT/locallaunch.py

@@ -0,0 +1,148 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import argparse
+import os
+
+from omegaconf import OmegaConf
+
+from mmpt.utils import recursive_config, overwrite_dir
+from mmpt_cli.localjob import LocalJob
+
+
+class JobLauncher(object):
+    JOB_CONFIG = {
+        "local": LocalJob,
+    }
+
+    def __init__(self, yaml_file):
+        self.yaml_file = yaml_file
+        job_key = "local"
+
+        if yaml_file.endswith(".yaml"):
+            config = recursive_config(yaml_file)
+            if config.task_type is not None:
+                job_key = config.task_type.split("_")[0]
+        else:
+            raise ValueError("unknown extension of job file:", yaml_file)
+        self.job_key = job_key
+
+    def __call__(self, job_type=None, dryrun=False):
+        if job_type is not None:
+            self.job_key = job_type.split("_")[0]
+        print("[JobLauncher] job_key", self.job_key)
+        job = JobLauncher.JOB_CONFIG[self.job_key](
+            self.yaml_file, job_type=job_type, dryrun=dryrun)
+        return job.submit()
+
+
+class Pipeline(object):
+    """a job that loads yaml config."""
+
+    def __init__(self, fn):
+        """
+        load a yaml config of a job and save generated configs as yaml for each task.
+        return: a list of files to run as specified by `run_task`.
+        """
+        if fn.endswith(".py"):
+            # a python command.
+            self.backend = "python"
+            self.run_yamls = [fn]
+            return
+
+        job_config = recursive_config(fn)
+        if job_config.base_dir is None:  # single file job config.
+            self.run_yamls = [fn]
+            return
+
+        self.project_dir = os.path.join("projects", job_config.project_dir)
+        self.run_dir = os.path.join("runs", job_config.project_dir)
+
+        if job_config.run_task is not None:
+            run_yamls = []
+            for stage in job_config.run_task:
+                # each stage can have multiple tasks running in parallel.
+                if OmegaConf.is_list(stage):
+                    stage_yamls = []
+                    for task_file in stage:
+                        stage_yamls.append(
+                            os.path.join(self.project_dir, task_file))
+                    run_yamls.append(stage_yamls)
+                else:
+                    run_yamls.append(os.path.join(self.project_dir, stage))
+            self.run_yamls = run_yamls
+        configs_to_save = self._overwrite_task(job_config)
+        self._save_configs(configs_to_save)
+
+    def __getitem__(self, idx):
+        yaml_files = self.run_yamls[idx]
+        if isinstance(yaml_files, list):
+            return [JobLauncher(yaml_file) for yaml_file in yaml_files]
+        return [JobLauncher(yaml_files)]
+
+    def __len__(self):
+        return len(self.run_yamls)
+
+    def _save_configs(self, configs_to_save: dict):
+        # save
+        os.makedirs(self.project_dir, exist_ok=True)
+        for config_file in configs_to_save:
+            config = configs_to_save[config_file]
+            print("saving", config_file)
+            OmegaConf.save(config=config, f=config_file)
+
+    def _overwrite_task(self, job_config):
+        configs_to_save = {}
+        self.base_project_dir = os.path.join("projects", job_config.base_dir)
+        self.base_run_dir = os.path.join("runs", job_config.base_dir)
+
+        for config_sets in job_config.task_group:
+            overwrite_config = job_config.task_group[config_sets]
+            if (
+                overwrite_config.task_list is None
+                or len(overwrite_config.task_list) == 0
+            ):
+                print(
+                    "[warning]",
+                    job_config.task_group,
+                    "has no task_list specified.")
+            # we don't want this added to a final config.
+            task_list = overwrite_config.pop("task_list", None)
+            for config_file in task_list:
+                config_file_path = os.path.join(
+                    self.base_project_dir, config_file)
+                config = recursive_config(config_file_path)
+                # overwrite it.
+                if overwrite_config:
+                    config = OmegaConf.merge(config, overwrite_config)
+                overwrite_dir(config, self.run_dir, basedir=self.base_run_dir)
+                save_file_path = os.path.join(self.project_dir, config_file)
+                configs_to_save[save_file_path] = config
+        return configs_to_save
+
+
+def main(args):
+    job_type = args.jobtype if args.jobtype else None
+    # parse multiple pipelines.
+    pipelines = [Pipeline(fn) for fn in args.yamls.split(",")]
+
+    for pipe_id, pipeline in enumerate(pipelines):
+        if not hasattr(pipeline, "project_dir"):
+            for job in pipeline[0]:
+                job(job_type=job_type, dryrun=args.dryrun)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("yamls", type=str)
+    parser.add_argument(
+        "--dryrun",
+        action="store_true",
+        help="generate configs and prepare the job without launching it.",
+    )
+    parser.add_argument(
+        "--jobtype", type=str, default="",
+        help="force to run jobs as specified.")
+    args = parser.parse_args()
+    main(args)

+ 12 - 0
examples/MMPT/mmpt/__init__.py

@@ -0,0 +1,12 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+try:
+    # fairseq user dir
+    from .datasets import FairseqMMDataset
+    from .losses import FairseqCriterion
+    from .models import FairseqMMModel
+    from .tasks import FairseqMMTask
+except ImportError:
+    pass

+ 10 - 0
examples/MMPT/mmpt/datasets/__init__.py

@@ -0,0 +1,10 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from .mmdataset import *
+
+try:
+    from .fairseqmmdataset import *
+except ImportError:
+    pass

+ 57 - 0
examples/MMPT/mmpt/datasets/fairseqmmdataset.py

@@ -0,0 +1,57 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+TODO (huxu): fairseq wrapper class for all dataset you defined: mostly MMDataset.
+"""
+
+from collections import OrderedDict
+
+from torch.utils.data import Dataset
+from torch.utils.data.dataloader import default_collate
+from fairseq.data import FairseqDataset, data_utils
+
+
+class FairseqMMDataset(FairseqDataset):
+    """
+    A wrapper class for MMDataset for fairseq.
+    """
+
+    def __init__(self, mmdataset):
+        if not isinstance(mmdataset, Dataset):
+            raise TypeError("mmdataset must be of type `torch.utils.data.Dataset`.")
+        self.mmdataset = mmdataset
+
+    def set_epoch(self, epoch, **unused):
+        super().set_epoch(epoch)
+        self.epoch = epoch
+
+    def __getitem__(self, idx):
+        with data_utils.numpy_seed(43211, self.epoch, idx):
+            return self.mmdataset[idx]
+
+    def __len__(self):
+        return len(self.mmdataset)
+
+    def collater(self, samples):
+        if hasattr(self.mmdataset, "collator"):
+            return self.mmdataset.collator(samples)
+        if len(samples) == 0:
+            return {}
+        if isinstance(samples[0], dict):
+            batch = OrderedDict()
+            for key in samples[0]:
+                if samples[0][key] is not None:
+                    batch[key] = default_collate([sample[key] for sample in samples])
+            return batch
+        else:
+            return default_collate(samples)
+
+    def size(self, index):
+        """dummy implementation: we don't use --max-tokens"""
+        return 1
+
+    def num_tokens(self, index):
+        """dummy implementation: we don't use --max-tokens"""
+        return 1

+ 111 - 0
examples/MMPT/mmpt/datasets/mmdataset.py

@@ -0,0 +1,111 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from collections import OrderedDict
+
+from torch.utils.data import Dataset
+from torch.utils.data.dataloader import default_collate
+
+from ..utils import set_seed
+
+
+class MMDataset(Dataset):
+    """
+    A generic multi-modal dataset.
+        Args:
+            `meta_processor`: a meta processor,
+                handling loading meta data and return video_id and text_id.
+            `video_processor`: a video processor,
+                handling e.g., decoding, loading .np files.
+            `text_processor`: a text processor,
+                handling e.g., tokenization.
+            `aligner`: combine the video and text feature
+                as one training example.
+    """
+
+    def __init__(
+        self,
+        meta_processor,
+        video_processor,
+        text_processor,
+        align_processor,
+    ):
+        self.split = meta_processor.split
+        self.meta_processor = meta_processor
+        self.video_processor = video_processor
+        self.text_processor = text_processor
+        self.align_processor = align_processor
+
+    def __len__(self):
+        return len(self.meta_processor)
+
+    def __getitem__(self, idx):
+        if self.split == "test":
+            set_seed(idx)
+        video_id, text_id = self.meta_processor[idx]
+        video_feature = self.video_processor(video_id)
+        text_feature = self.text_processor(text_id)
+        output = self.align_processor(video_id, video_feature, text_feature)
+        # TODO (huxu): the following is for debug purpose.
+        output.update({"idx": idx})
+        return output
+
+    def collater(self, samples):
+        """This collator is deprecated.
+        set self.collator = MMDataset.collater.
+        see collator in FairseqMMDataset.
+        """
+
+        if len(samples) == 0:
+            return {}
+        if isinstance(samples[0], dict):
+            batch = OrderedDict()
+            for key in samples[0]:
+                if samples[0][key] is not None:
+                    batch[key] = default_collate(
+                        [sample[key] for sample in samples])
+                # if torch.is_tensor(batch[key]):
+                #    print(key, batch[key].size())
+                # else:
+                #    print(key, len(batch[key]))
+            return batch
+        else:
+            return default_collate(samples)
+
+    def print_example(self, output):
+        print("[one example]", output["video_id"])
+        if (
+            hasattr(self.align_processor, "subsampling")
+            and self.align_processor.subsampling is not None
+            and self.align_processor.subsampling > 1
+        ):
+            for key in output:
+                if torch.is_tensor(output[key]):
+                    output[key] = output[key][0]
+
+        # search tokenizer to translate ids back.
+        tokenizer = None
+        if hasattr(self.text_processor, "tokenizer"):
+            tokenizer = self.text_processor.tokenizer
+        elif hasattr(self.align_processor, "tokenizer"):
+            tokenizer = self.align_processor.tokenizer
+        if tokenizer is not None:
+            caps = output["caps"].tolist()
+            if isinstance(caps[0], list):
+                caps = caps[0]
+            print("caps", tokenizer.decode(caps))
+            print("caps", tokenizer.convert_ids_to_tokens(caps))
+
+        for key, value in output.items():
+            if torch.is_tensor(value):
+                if len(value.size()) >= 3:  # attention_mask.
+                    print(key, value.size())
+                    print(key, "first", value[0, :, :])
+                    print(key, "last", value[-1, :, :])
+                else:
+                    print(key, value)
+        print("[end of one example]")

+ 13 - 0
examples/MMPT/mmpt/evaluators/__init__.py

@@ -0,0 +1,13 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from .metric import *
+from .evaluator import *
+
+
+# experimental.
+try:
+    from .expmetric import *
+except ImportError:
+    pass

+ 54 - 0
examples/MMPT/mmpt/evaluators/evaluator.py

@@ -0,0 +1,54 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import os
+import glob
+import numpy as np
+
+from . import metric as metric_path
+from . import predictor as predictor_path
+
+
+class Evaluator(object):
+    """
+    perform evaluation on a single (downstream) task.
+    make this both offline and online.
+    TODO(huxu) saving evaluation results.
+    """
+
+    def __init__(self, config, eval_dataloader=None):
+        if config.metric is None:
+            raise ValueError("config.metric is", config.metric)
+        metric_cls = getattr(metric_path, config.metric)
+        self.metric = metric_cls(config)
+        if config.predictor is None:
+            raise ValueError("config.predictor is", config.predictor)
+        predictor_cls = getattr(predictor_path, config.predictor)
+        self.predictor = predictor_cls(config)
+        self.eval_dataloader = eval_dataloader
+
+    def __call__(self):
+        try:
+            print(self.predictor.pred_dir)
+            for pred_file in glob.glob(
+                    self.predictor.pred_dir + "/*_merged.npy"):
+                outputs = np.load(pred_file)
+                results = self.metric.compute_metrics(outputs)
+                self.metric.print_computed_metrics(results)
+
+            outputs = np.load(os.path.join(
+                    self.predictor.pred_dir, "merged.npy"))
+            results = self.metric.compute_metrics(outputs)
+            return {"results": results, "metric": self.metric}
+        except FileNotFoundError:
+            print("\n[missing]", self.predictor.pred_dir)
+            return {}
+
+    def evaluate(self, model, eval_dataloader=None, output_file="merged"):
+        if eval_dataloader is None:
+            eval_dataloader = self.eval_dataloader
+        outputs = self.predictor.predict_loop(
+            model, eval_dataloader, output_file)
+        results = self.metric.compute_metrics(**outputs)
+        return results

+ 313 - 0
examples/MMPT/mmpt/evaluators/metric.py

@@ -0,0 +1,313 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import json
+
+
+class Metric(object):
+    def __init__(self, config, metric_names):
+        self.metric_names = metric_names
+
+    def best_metric(self, metric):
+        return metric[self.metric_names[0]]
+
+    def save_metrics(self, fn, metrics):
+        with open(fn, "w") as fw:
+            json.dump(metrics, fw)
+
+    def print_computed_metrics(self, metrics):
+        raise NotImplementedError
+
+
+class RetrievalMetric(Metric):
+    """
+    this is modified from `howto100m/metrics.py`.
+    History of changes:
+    refactor as a class.
+    add metric_key in __init__
+    """
+
+    def __init__(self, config, metric_names=["R1", "R5", "R10", "MR"]):
+        super().__init__(config, metric_names)
+        self.error = False  # TODO(huxu): add to config to print error.
+
+    def compute_metrics(self, outputs, texts, **kwargs):
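+        # `outputs` is a square similarity matrix where the ground-truth match
+        # for row i is column i. The rank of the diagonal entry within its row,
+        # sorted by descending score, yields R@1/R@5/R@10 and the median rank.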
+        x = outputs
+        sx = np.sort(-x, axis=1)
+        d = np.diag(-x)
+        d = d[:, np.newaxis]
+        ind = sx - d
+        ind = np.where(ind == 0)
+        ind = ind[1]
+        metrics = {}
+        metrics["R1"] = float(np.sum(ind == 0)) / len(ind)
+        metrics["R5"] = float(np.sum(ind < 5)) / len(ind)
+        metrics["R10"] = float(np.sum(ind < 10)) / len(ind)
+        metrics["MR"] = np.median(ind) + 1
+
+        max_idx = np.argmax(outputs, axis=1)
+        if self.error:
+            # print top-20 errors.
+            error = []
+            for ex_idx in range(20):
+                error.append((texts[ex_idx], texts[max_idx[ex_idx]]))
+            metrics["error"] = error
+        return metrics
+
+    def print_computed_metrics(self, metrics):
+        r1 = metrics["R1"]
+        r5 = metrics["R5"]
+        r10 = metrics["R10"]
+        mr = metrics["MR"]
+        print(
+            "R@1: {:.4f} - R@5: {:.4f} - R@10: {:.4f} - Median R: {}".format(
+                r1, r5, r10, mr
+            )
+        )
+        if "error" in metrics:
+            print(metrics["error"])
+
+
+class DiDeMoMetric(Metric):
+    """
+    History of changes:
+    python 2.x to python 3.x.
+    merge utils.py into eval to save one file.
+    reference: https://github.com/LisaAnne/LocalizingMoments/blob/master/utils/eval.py
+    Code to evaluate your results on the DiDeMo dataset.
+    """
+    def __init__(self, config, metric_names=["rank1", "rank5", "miou"]):
+        super().__init__(config, metric_names)
+
+    def compute_metrics(self, outputs, targets, **kwargs):
+        assert len(outputs) == len(targets)
+        rank1, rank5, miou = self._eval_predictions(outputs, targets)
+        metrics = {
+            "rank1": rank1,
+            "rank5": rank5,
+            "miou": miou
+        }
+        return metrics
+
+    def print_computed_metrics(self, metrics):
+        rank1 = metrics["rank1"]
+        rank5 = metrics["rank5"]
+        miou = metrics["miou"]
+        # print("Average rank@1: %f" % rank1)
+        # print("Average rank@5: %f" % rank5)
+        # print("Average iou: %f" % miou)
+
+        print(
+            "Average rank@1: {:.4f} Average rank@5: {:.4f} Average iou: {:.4f}".format(
+                rank1, rank5, miou
+            )
+        )
+
+    def _iou(self, pred, gt):
+        intersection = max(0, min(pred[1], gt[1]) + 1 - max(pred[0], gt[0]))
+        union = max(pred[1], gt[1]) + 1 - min(pred[0], gt[0])
+        return float(intersection)/union
+
+    def _rank(self, pred, gt):
+        return pred.index(tuple(gt)) + 1
+
+    def _eval_predictions(self, segments, data):
+        '''
+        Inputs:
+        segments: For each item in the ground truth data, rank possible video segments given the description and video.
+            In DiDeMo, there are 21 possible moments extracted for each video, so the list of video segments will be of length 21.
+            The first video segment should be the video segment that best corresponds to the text query.
+            There are 4180 sentences in the validation data, so when evaluating a model on the val dataset,
+            segments should be a list of length 4180, and each item in segments should be a list of length 21.
+        data: ground truth data
+        '''
+        average_ranks = []
+        average_iou = []
+        for s, d in zip(segments, data):
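+            # each ground-truth item carries several human-annotated moments in
+            # d['times']; following the reference eval, keep the mean of the
+            # three largest IoUs and the mean of the three best (smallest) ranks.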
+            pred = s[0]
+            ious = [self._iou(pred, t) for t in d['times']]
+            average_iou.append(np.mean(np.sort(ious)[-3:]))
+            ranks = [self._rank(s, t) for t in d['times'] if tuple(t) in s]  # the membership check skips ground-truth (s, e) pairs that are missing from the prediction.
+            average_ranks.append(np.mean(np.sort(ranks)[:3]))
+        rank1 = np.sum(np.array(average_ranks) <= 1)/float(len(average_ranks))
+        rank5 = np.sum(np.array(average_ranks) <= 5)/float(len(average_ranks))
+        miou = np.mean(average_iou)
+
+        # print("Average rank@1: %f" % rank1)
+        # print("Average rank@5: %f" % rank5)
+        # print("Average iou: %f" % miou)
+        return rank1, rank5, miou
+
+
+class NLGMetric(Metric):
+    def __init__(
+        self,
+        config,
+        metric_names=[
+            "Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4",
+            "METEOR", "ROUGE_L", "CIDEr"
+        ]
+    ):
+        super().__init__(config, metric_names)
+        # please install NLGEval from `https://github.com/Maluuba/nlg-eval`
+        from nlgeval import NLGEval
+        self.nlg = NLGEval()
+
+    def compute_metrics(self, outputs, targets, **kwargs):
+        return self.nlg.compute_metrics(
+            hyp_list=outputs, ref_list=targets)
+
+    def print_computed_metrics(self, metrics):
+        Bleu_1 = metrics["Bleu_1"]
+        Bleu_2 = metrics["Bleu_2"]
+        Bleu_3 = metrics["Bleu_3"]
+        Bleu_4 = metrics["Bleu_4"]
+        METEOR = metrics["METEOR"]
+        ROUGE_L = metrics["ROUGE_L"]
+        CIDEr = metrics["CIDEr"]
+
+        print(
+            "Bleu_1: {:.4f} - Bleu_2: {:.4f} - Bleu_3: {:.4f} - Bleu_4: {:.4f} - METEOR: {:.4f} - ROUGE_L: {:.4f} - CIDEr: {:.4f}".format(
+                Bleu_1, Bleu_2, Bleu_3, Bleu_4, METEOR, ROUGE_L, CIDEr
+            )
+        )
+
+
+class QAMetric(Metric):
+    def __init__(
+        self,
+        config,
+        metric_names=["acc"]
+    ):
+        super().__init__(config, metric_names)
+
+    def compute_metrics(self, outputs, targets, **kwargs):
+        from sklearn.metrics import accuracy_score
+        return {"acc": accuracy_score(targets, outputs)}
+
+    def print_computed_metrics(self, metrics):
+        print("acc: {:.4f}".format(metrics["acc"]))
+
+
+class COINActionSegmentationMetric(Metric):
+    """
+    The COIN dataset lists 3 repos for Action Segmentation:
+    Action Sets, NeuralNetwork-Viterbi, and TCFPN-ISBA.
+    The first and second are the same.
+    https://github.com/alexanderrichard/action-sets/blob/master/eval.py
+
+    Future reference for the third:
+    `https://github.com/Zephyr-D/TCFPN-ISBA/blob/master/utils/metrics.py`
+    """
+    def __init__(self, config, metric_name=["frame_acc"]):
+        super().__init__(config, metric_name)
+
+    def compute_metrics(self, outputs, targets):
+        n_errors = sum(outputs != targets)
+        n_frames = len(targets)
+        return {"frame_acc": 1.0 - float(n_errors) / n_frames}
+
+    def print_computed_metrics(self, metrics):
+        fa = metrics["frame_acc"]
+        print("frame accuracy:", fa)
+
+
+class CrossTaskMetric(Metric):
+    def __init__(self, config, metric_names=["recall"]):
+        super().__init__(config, metric_names)
+
+    def compute_metrics(self, outputs, targets, **kwargs):
+        """refactored from line 166:
+        https://github.com/DmZhukov/CrossTask/blob/master/train.py"""
+
+        recalls = self._get_recalls(Y_true=targets, Y_pred=outputs)
+        results = {}
+        for task, rec in recalls.items():
+            results[str(task)] = rec
+
+        avg_recall = np.mean(list(recalls.values()))
+        results["recall"] = avg_recall
+        return results
+
+    def print_computed_metrics(self, metrics):
+        print('Recall: {0:0.3f}'.format(metrics["recall"]))
+        for task in metrics:
+            if task != "recall":
+                print('Task {0}. Recall = {1:0.3f}'.format(
+                    task, metrics[task]))
+
+    def _get_recalls(self, Y_true, Y_pred):
+        """refactored from
+        https://github.com/DmZhukov/CrossTask/blob/master/train.py"""
+
+        step_match = {task: 0 for task in Y_true.keys()}
+        step_total = {task: 0 for task in Y_true.keys()}
+        for task, ys_true in Y_true.items():
+            ys_pred = Y_pred[task]
+            for vid in set(ys_pred.keys()).intersection(set(ys_true.keys())):
+                y_true = ys_true[vid]
+                y_pred = ys_pred[vid]
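+                # y_true/y_pred are (T, K) frame-by-step assignment matrices;
+                # recall counts how many ground-truth steps are hit by the
+                # predicted assignment, out of all steps present in the video.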
+                step_total[task] += (y_true.sum(axis=0) > 0).sum()
+                step_match[task] += (y_true*y_pred).sum()
+        recalls = {
+            task: step_match[task] / n for task, n in step_total.items()}
+        return recalls
+
+
+class ActionRecognitionMetric(Metric):
+    def __init__(
+        self,
+        config,
+        metric_names=["acc", "acc_splits", "r1_splits", "r5_splits", "r10_splits"]
+    ):
+        super().__init__(config, metric_names)
+
+    def compute_metrics(self, outputs, targets, splits, **kwargs):
+        all_video_embd = outputs
+        labels = targets
+        split1, split2, split3 = splits
+        accs = []
+        r1s = []
+        r5s = []
+        r10s = []
+        for split in range(3):
+            if split == 0:
+                s = split1
+            elif split == 1:
+                s = split2
+            else:
+                s = split3
+
+            X_pred = all_video_embd[np.where(s == 2)[0]]
+            label_test = labels[np.where(s == 2)[0]]
+            logits = X_pred
+            X_pred = np.argmax(X_pred, axis=1)
+            acc = np.sum(X_pred == label_test) / float(len(X_pred))
+            accs.append(acc)
+            # compute recall.
+            sorted_pred = (-logits).argsort(axis=-1)
+            label_test_sp = label_test.reshape(-1, 1)
+
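+            # recall@k: 1 if the ground-truth label appears among the top-k
+            # logits for a clip, averaged over the split's test set.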
+            r1 = np.mean((sorted_pred[:, :1] == label_test_sp).sum(axis=1), axis=0)
+            r5 = np.mean((sorted_pred[:, :5] == label_test_sp).sum(axis=1), axis=0)
+            r10 = np.mean((sorted_pred[:, :10] == label_test_sp).sum(axis=1), axis=0)
+            r1s.append(r1)
+            r5s.append(r5)
+            r10s.append(r10)
+
+        return {"acc": accs[0], "acc_splits": accs, "r1_splits": r1s, "r5_splits": r5s, "r10_splits": r10s}
+
+    def print_computed_metrics(self, metrics):
+        for split, acc in enumerate(metrics["acc_splits"]):
+            print("Top 1 accuracy on split {}: {}; r1 {}; r5 {}; r10 {}".format(
+                split + 1, acc,
+                metrics["r1_splits"][split],
+                metrics["r5_splits"][split],
+                metrics["r10_splits"][split],
+                )
+            )

+ 595 - 0
examples/MMPT/mmpt/evaluators/predictor.py

@@ -0,0 +1,595 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import os
+import random
+import json
+import numpy as np
+import torch
+import pickle
+import math
+
+from tqdm import tqdm
+
+
+class Predictor(object):
+    """this base class is used to save predictions to disk
+        (and being called by a evaluator later).
+        Predictor has minimum support of single gpu prediction.
+    """
+    def __init__(self, config):
+        self.pred_dir = None  # on-the-fly eval does not save the results.
+        if hasattr(config, "eval") and config.eval is not None:
+            self.pred_dir = config.eval.save_path
+            os.makedirs(self.pred_dir, exist_ok=True)
+
+    def __call__(self, outputs):
+        """extract the prediction and save it."""
+        raise NotImplementedError
+
+    def predict_loop(self, model, eval_dataloader, output_file=None):
+        """on-the-fly prediction on a single gpu."""
+        self.full_scores = []
+        model.eval()
+        model = model.to(0)
+        with torch.no_grad():
+            for data in eval_dataloader:
+                data = self.to_ctx(data)
+                outputs = model(**data)
+                outputs.update(data)
+                self(outputs)
+        return self.finalize(output_file)
+
+    def finalize(self, output_file):
+        pass
+
+    def to_ctx(self, data, ctx=0, dtype=None):
+        if isinstance(data, dict):
+            for key in data:
+                if torch.is_tensor(data[key]):
+                    if dtype is not None and data[key].dtype == torch.float32:
+                        data[key] = data[key].to(dtype)
+                    data[key] = data[key].to(ctx)
+            return data
+        else:
+            raise ValueError("non-dict type of batch is not supported yet.")
+
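+# How predictors and metrics are glued together (a sketch; the names
+# `predictor`, `metric`, `model` and `eval_dataloader` are assumptions here,
+# and the actual driver lives in the MMPT evaluator/runner code):
+#
+#     results = predictor.predict_loop(model, eval_dataloader)
+#     scores = metric.compute_metrics(**results)
+#     metric.print_computed_metrics(scores)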
+
+class NLGPredictor(Predictor):
+    """Predicting Text from MMFusion models."""
+    """TODO: make a context."""
+    def __init__(self, config):
+        super().__init__(config)
+        from transformers import AutoTokenizer
+
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            config.dataset.bert_name,
+            bos_token="[CLS]", eos_token="[SEP]")
+        self.bos_token_id = self.tokenizer.bos_token_id
+        self.eos_token_id = self.tokenizer.eos_token_id
+
+    def predict_loop(self, model, eval_dataloader, output_file=None):
+        """TODO: refactor base classes."""
+        ctx = 0
+        outputs = {"outputs": [], "targets": [[]]}
+        model.eval()
+        model = model.to(ctx)
+        with torch.no_grad():
+            for data in tqdm(eval_dataloader):
+                data = self.to_ctx(data, ctx)
+                self(data, model, outputs)
+        return self.finalize(outputs, output_file)
+
+    def __call__(self, data, model, outputs):
+        data.update({
+            "bos_token_id": self.bos_token_id,
+            "eos_token_id": self.eos_token_id
+        })
+
+        output = model.generate(**data)
+        assert len(output) == len(data["ref"])
+        for idx, _output in enumerate(output):
+            generated_text = self.tokenizer.decode(
+                _output, skip_special_tokens=True)
+            if generated_text == "":
+                generated_text = "none"
+            outputs["outputs"].append(generated_text)
+            outputs["targets"][0].append(data["ref"][idx])
+            if random.random() < 0.001:
+                print("_output", _output)
+                print("generated_text", generated_text)
+                print("ref", data["ref"][idx])
+
+    def finalize(self, outputs, output_file=None):
+        if output_file is not None:
+            with open(os.path.join(
+                    self.pred_dir, output_file + ".json"), "w") as fw:
+                json.dump(outputs, fw, indent=4)
+        return outputs
+
+
+class RetrievalPredictor(Predictor):
+    """generated `pooled_video` and `pooled_text`."""
+    def __init__(self, config):
+        super().__init__(config)
+        from transformers import AutoTokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            config.dataset.bert_name)
+
+    def predict_loop(
+        self,
+        model,
+        eval_dataloader,
+        output_file="retrieval.npy"
+    ):
+        """on-the-fly prediction on a single gpu."""
+        full_scores = []
+        texts = []
+        model.eval()
+        model = model.cuda()
+        with torch.no_grad():
+            for data in eval_dataloader:
+                # convert to dict.
+                if not isinstance(data, dict):
+                    data = {
+                        "caps": data[0],
+                        "cmasks": data[1],
+                        "vfeats": data[2],
+                        "vmasks": data[3],
+                        "video_id": data[4]
+                    }
+                data = self.to_ctx(data)
+                outputs = model(**data)
+                outputs.update(data)
+                self(outputs, full_scores)
+                for _cap in data["caps"]:
+                    texts.append(
+                        self.tokenizer.decode(_cap, skip_special_tokens=True)
+                    )
+
+        return self.finalize(full_scores, texts, output_file)
+
+    def __call__(self, sample, full_scores):
+        scores = self._get_pooled_outputs(sample)
+        self._append_scores(scores, full_scores)
+
+    def finalize(self, full_scores, texts, output_file=None):
+        outputs = self._aggregate_scores(full_scores)
+        if output_file is not None:
+            np.save(os.path.join(self.pred_dir, output_file + ".npy"), outputs)
+        return {"outputs": outputs, "texts": texts}
+
+    def _get_pooled_outputs(self, outputs):
+        if "pooled_video" in outputs:
+            return outputs["pooled_video"], outputs["pooled_text"]
+        else:
+            raise ValueError("unknown format of outputs.")
+
+    def _append_scores(self, scores, full_scores):
+        assert len(scores) == 2
+        if len(full_scores) == 0:
+            full_scores.append([])
+            full_scores.append([])
+        full_scores[0].append(scores[0].cpu().detach().numpy())
+        full_scores[1].append(scores[1].cpu().detach().numpy())
+
+    def _aggregate_scores(self, scores):
+        assert len(scores) == 2
+        video_hidden = np.concatenate(scores[0], axis=0)
+        text_hidden = np.concatenate(scores[1], axis=0)
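+        # the final score matrix is (num_texts, num_videos); its diagonal holds
+        # the ground-truth pairs consumed by RetrievalMetric.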
+        # clear up.
+        self.full_scores = []
+        return np.matmul(text_hidden, video_hidden.T)
+
+
+class QAPredictor(Predictor):
+    """generated `pooled_video` and `pooled_text`."""
+    def __init__(self, config):
+        super().__init__(config)
+        """predictor maintains scores and aggregate them."""
+
+    def predict_loop(self, model, eval_dataloader, output_file="qa.npy"):
+        """on-the-fly prediction on a single gpu."""
+        self.full_scores = []
+        model.eval()
+        model = model.cuda()
+        with torch.no_grad():
+            for data in eval_dataloader:
+                # reshape ans and dup video 5 times.
+                v_len = data["vfeats"].size(1)
+                hidden_size = data["vfeats"].size(2)
+                data["vfeats"] = data["vfeats"].unsqueeze(1).repeat(1, 5, 1, 1).view(-1, v_len, hidden_size)
+                data["vmasks"] = data["vmasks"].unsqueeze(1).repeat(1, 5, 1).view(-1, v_len)
+
+                t_len = data["caps"].size(-1)
+                data["caps"] = data["caps"].view(-1, t_len)
+                data["cmasks"] = data["cmasks"].view(-1, t_len)
+
+                data = self.to_ctx(data)
+                outputs = model(**data)
+                outputs.update(data)
+                self(outputs)
+        return self.finalize(output_file)
+
+    def __call__(self, sample):
+        hidden_size = sample["pooled_video"].size(-1)
+        pooled_video = sample["pooled_video"].view(-1, 5, hidden_size)
+        pooled_text = sample["pooled_text"].view(-1, 5, hidden_size)
+        scores = torch.bmm(pooled_video, pooled_text.transpose(2, 1))
+        scores = scores.argmax(-1)
+        self._append_scores(scores[:, 0], sample["answers"], self.full_scores)
+
+    def finalize(self, output_file=None):
+        outputs, targets = self._aggregate_scores(self.full_scores)
+        if output_file is not None:
+            np.save(os.path.join(self.pred_dir, output_file + ".npy"), outputs)
+        return {"outputs": outputs, "targets": targets}
+
+    def _append_scores(self, scores, answers, full_scores):
+        if len(full_scores) == 0:
+            full_scores.append([])
+            full_scores.append([])
+        full_scores[0].append(scores.cpu().detach().numpy())
+        full_scores[1].append(answers.cpu().detach().numpy())
+
+    def _aggregate_scores(self, scores):
+        assert len(scores) == 2
+        outputs = np.concatenate(scores[0], axis=0)
+        targets = np.concatenate(scores[1], axis=0)
+        # clear up.
+        self.full_scores = []
+        return outputs, targets
+
+
+class CrossTaskPredictor(Predictor):
+    """
+    CrossTaskPredictor needs to compute the average of logits
+    over overlapping sliding windows.
+    """
+    def __init__(self, config):
+        super().__init__(config)
+        self.lsm = torch.nn.LogSoftmax(dim=1)
+        self.max_video_len = config.dataset.max_video_len
+        self.sliding_window = config.dataset.sliding_window
+        self.sliding_window_size = config.dataset.sliding_window_size
+        self.annotation_path = config.dataset.annotation_path
+
+    def predict_loop(self, model, eval_dataloader, output_file="result.pkl"):
+        """refactored from line 144:
+        https://github.com/DmZhukov/CrossTask/blob/master/train.py
+        """
+        ctx = 0
+        model.eval()
+        model = model.to(ctx)
+        # this is not a loss; it just computes neg_log_prob.
+        Y_pred = {}
+        Y_true = {}
+        with torch.no_grad():
+            for batch in eval_dataloader:
+                self(batch, model, Y_pred, Y_true)
+        return self.finalize(Y_pred, Y_true, output_file)
+
+    def __call__(self, sample, model, Y_pred, Y_true):
+        # please install dp from `https://github.com/DmZhukov/CrossTask`
+        from dp import dp
+        vid, task = sample['video_id'][0], sample['task'][0]
+        sample = self.to_ctx(sample)
+        # compute the average logits over sliding windows.
+        output = model(**sample)
+        batch_logits = output["logits"].cpu()
+
+        video_len = sample["video_len"][0]
+
+        # the following version is slow.
+        logits = torch.zeros((video_len, batch_logits.size(1)))
+        logits_counts = torch.zeros((video_len, 1), dtype=torch.long)
+        # use the same loop as aligner to recover.
+        batch_logit_idx = 0
+        for window_start in range(0, video_len, self.sliding_window):
+            video_end = min(video_len - window_start, self.sliding_window_size)
+            logits[window_start: window_start + video_end] += batch_logits[
+                batch_logit_idx: batch_logit_idx + video_end]
+            batch_logit_idx += video_end
+            logits_counts[window_start: window_start + video_end] += torch.ones((video_end, 1), dtype=torch.long)
+
+            if (video_len - window_start) <= self.sliding_window_size:
+                break
+
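+        # each frame can be covered by several overlapping windows; dividing the
+        # accumulated logits by the per-frame counts yields the average logit.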
+        logits /= logits_counts
+        assert logits.size() == (video_len, batch_logits.size(1)), "{}, {}".format(logits.size(), video_len)
+
+        O = self.lsm(logits)
+        y = np.zeros(O.size(), dtype=np.float32)
+        dp(y, -O.detach().cpu().numpy())
+        if task not in Y_pred:
+            Y_pred[task] = {}
+        Y_pred[task][vid] = y
+        annot_path = os.path.join(
+            self.annotation_path, task+'_'+vid+'.csv')
+        if os.path.exists(annot_path):
+            if task not in Y_true:
+                Y_true[task] = {}
+            Y_true[task][vid] = self._read_assignment(
+                *y.shape, annot_path)
+
+    def finalize(self, Y_pred, Y_true, output_file=None):
+        if output_file is not None:
+            with open(
+                    os.path.join(self.pred_dir, output_file + ".pkl"),
+                    "wb") as fw:
+                pickle.dump(
+                    {"Y_pred": Y_pred, "Y_true": Y_true}, fw,
+                    protocol=pickle.HIGHEST_PROTOCOL)
+        return {"outputs": Y_pred, "targets": Y_true}
+
+    def _read_assignment(self, T, K, path):
+        """
+        refactored from https://github.com/DmZhukov/CrossTask/blob/master/data.py
+        How to interpret constraints on the loss that is going to be minimized:
+        lambd is a big number;
+        self.lambd * C is a big number for all valid positions (the csv stores the invalid ones)
+
+        def forward(self, O, Y, C):
+            return (Y*(self.lambd * C - self.lsm(O))).mean(dim=0).sum()
+
+        This will load the csv file and fill in the step column from the start row to the end row.
+        """
+
+        Y = np.zeros([T, K], dtype=np.uint8)
+        with open(path, 'r') as f:
+            for line in f:
+                step, start, end = line.strip().split(',')
+                start = int(math.floor(float(start)))
+                end = int(math.ceil(float(end)))
+                step = int(step) - 1
+                Y[start:end, step] = 1
+        return Y
+
+
+class COINPredictor(Predictor):
+    """
+    COINPredictor is similar to CrossTaskPredictor in its use of sliding windows.
+    """
+    def __init__(self, config):
+        super().__init__(config)
+        self.max_video_len = config.dataset.max_video_len
+        self.sliding_window = config.dataset.sliding_window
+        self.sliding_window_size = config.dataset.sliding_window_size
+
+    def predict_loop(self, model, eval_dataloader, output_file="result.pkl"):
+        """refactored from line 144:
+        https://github.com/DmZhukov/CrossTask/blob/master/train.py
+        """
+        ctx = 0
+        model.eval()
+        model = model.to(ctx)
+        # this is not a loss; it just computes neg_log_prob.
+        Y_pred = []
+        Y_true = []
+        with torch.no_grad():
+            for batch in eval_dataloader:
+                self(batch, model, Y_pred, Y_true)
+        return self.finalize(Y_pred, Y_true, output_file)
+
+    def __call__(self, sample, model, Y_pred, Y_true):
+        sample = self.to_ctx(sample)
+        # compute the average logits over sliding windows.
+        output = model(**sample)
+        logits = self._merge_windows(sample, output)
+        Y_pred.append(logits.argmax(dim=1))
+        Y_true.append(sample["video_targets"].squeeze(0).cpu())
+
+    def _merge_windows(self, sample, output):
+        targets = sample["targets"].reshape(-1).cpu()
+        valid_mask = targets != -100
+        targets = targets[valid_mask]
+        batch_logits = output["logits"].cpu()
+        batch_logits = batch_logits.reshape(-1, batch_logits.size(-1))
+        batch_logits = batch_logits[valid_mask]
+
+        video_len = sample["video_len"][0]
+
+        # the following version is slow.
+        logits = torch.zeros((video_len, batch_logits.size(1)))
+        logits_counts = torch.zeros((video_len, 1), dtype=torch.long)
+        # use the same loop as aligner to recover.
+        batch_logit_idx = 0
+        for window_start in range(0, video_len, self.sliding_window):
+            video_end = min(video_len - window_start, self.sliding_window_size)
+            logits[window_start: window_start + video_end] += batch_logits[
+                batch_logit_idx: batch_logit_idx + video_end]
+            batch_logit_idx += video_end
+            logits_counts[window_start: window_start + video_end] += torch.ones((video_end, 1), dtype=torch.long)
+            if (video_len - window_start) <= self.sliding_window_size:
+                break
+        logits /= logits_counts
+        assert logits.size() == (video_len, batch_logits.size(1)), "{}, {}".format(logits.size(), video_len)
+        return logits
+
+    def finalize(self, Y_pred, Y_true, output_file=None):
+        Y_pred = torch.cat(Y_pred, dim=0).numpy()
+        Y_true = torch.cat(Y_true, dim=0).numpy()
+        assert len(Y_pred) == len(Y_true)
+
+        error_mask = Y_pred != Y_true
+        print("sample error", Y_pred[error_mask][:10], Y_true[error_mask][:10])
+        print("sample error", Y_pred[error_mask][10:20], Y_true[error_mask][10:20])
+
+        if output_file is not None:
+            with open(
+                    os.path.join(self.pred_dir, output_file + ".pkl"),
+                    "wb") as fw:
+                pickle.dump(
+                    {"Y_pred": Y_pred, "Y_true": Y_true}, fw,
+                    protocol=pickle.HIGHEST_PROTOCOL)
+        return {"outputs": Y_pred, "targets": Y_true}
+
+
+class COINZSPredictor(COINPredictor):
+    """
+    COINZSPredictor for COIN zero-shot prediction.
+    """
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.dataset_config = config.dataset
+
+    def predict_loop(self, model, eval_dataloader, output_file="result.pkl"):
+        """refactored from line 144:
+        https://github.com/DmZhukov/CrossTask/blob/master/train.py
+        """
+        ctx = 0
+        model.eval()
+        model = model.to(ctx)
+
+        with torch.no_grad():
+            outputs = eval_dataloader.dataset.meta_processor.meta_text_labels(
+                self.dataset_config)
+            outputs = self.to_ctx(outputs, ctx)
+            label_hidden_states = model.forward_text(**outputs).cpu()
+            label_sim = label_hidden_states @ label_hidden_states.t()
+            num_labels = label_sim.size(0)
+            eye_mask = ~torch.eye(num_labels, dtype=torch.bool)
+            label_sim = label_sim.masked_select(eye_mask).view(num_labels, num_labels - 1)
+            lbd = label_sim.max()
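+            # lbd is the largest similarity between two *different* labels; a
+            # frame is assigned a label only when its score beats this
+            # threshold, otherwise it is left as background ("O").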
+
+        # this is not a loss; it just computes neg_log_prob.
+        Y_pred = []
+        Y_true = []
+        with torch.no_grad():
+            for batch in eval_dataloader:
+                self(batch, label_hidden_states, model, lbd, Y_pred, Y_true)
+        return self.finalize(Y_pred, Y_true, output_file)
+
+    def reshape_subsample(self, sample):
+        for key in sample:
+            if torch.is_tensor(sample[key]):
+                sample[key] = self.flat_subsample(sample[key])
+        return sample
+
+    def flat_subsample(self, tensor):
+        if len(tensor.size()) > 1 and tensor.size(0) == 1:
+            tensor = tensor.squeeze(0)
+        return tensor
+
+    def __call__(self, sample, label_hidden_states, model, lbd, Y_pred, Y_true):
+        sample = self.reshape_subsample(sample)
+        sample = self.to_ctx(sample)
+        # compute the average logits over sliding windows.
+        sample["output_hidden_states"] = True
+        video_outputs = model.forward_video(**sample).cpu()
+        output = {"logits": video_outputs[:, 1:sample["vmasks"].size(1)+1] @ label_hidden_states.t()}
+        logits = self._merge_windows(sample, output)
+        # logic of zero-shot for sequence labeling.
+        logits_argmax = logits.argmax(dim=1) + 1  # 0 is "O" label.
+        logits_max = logits.max(dim=1)[0]
+
+        pred = torch.zeros_like(logits_argmax)
+        label_select = logits_max > lbd  # 73 or 74
+        pred[label_select] = logits_argmax[label_select]
+
+        Y_pred.append(pred)
+        Y_true.append(sample["video_targets"].squeeze(0).cpu())
+
+    def finalize(self, Y_pred, Y_true, output_file=None):
+        Y_pred = torch.cat(Y_pred, dim=0).numpy()
+        Y_true = torch.cat(Y_true, dim=0).numpy()
+        assert len(Y_pred) == len(Y_true)
+
+        error_mask = Y_pred != Y_true
+        print("sample error", Y_pred[error_mask][:10], Y_true[error_mask][:10])
+        print("sample error", Y_pred[error_mask][10:20], Y_true[error_mask][10:20])
+
+        if output_file is not None:
+            with open(
+                    os.path.join(self.pred_dir, output_file + ".pkl"),
+                    "wb") as fw:
+                pickle.dump(
+                    {"Y_pred": Y_pred, "Y_true": Y_true}, fw,
+                    protocol=pickle.HIGHEST_PROTOCOL)
+        return {"outputs": Y_pred, "targets": Y_true}
+
+
+class DiDeMoPredictor(Predictor):
+    """reference: https://github.com/LisaAnne/LocalizingMoments/blob/master/utils/eval.py
+    https://github.com/LisaAnne/LocalizingMoments/blob/master/utils/data_processing.py
+    """
+    def __init__(self, config):
+        super().__init__(config)
+        # load targets.
+        with open(config.dataset.test_path) as data_file:
+            self.test_data = json.load(data_file)
+
+    def predict_loop(self, model, eval_dataloader, output_file="didemo.npy"):
+        """
+        TODO: two solutions here.
+        """
+        import itertools
+        # 21 chunks.
+        self.possible_segments = [(0,0), (1,1), (2,2), (3,3), (4,4), (5,5)]
+        for i in itertools.combinations(range(6), 2):
+            self.possible_segments.append(i)
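+        # the 6 single chunks plus C(6, 2) = 15 multi-chunk spans give the 21
+        # candidate moments used by DiDeMo.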
+        # pick segments from a video.
+
+        """on-the-fly prediction on a single gpu."""
+        self.full_scores = []
+        model.eval()
+        model = model.cuda()
+        with torch.no_grad():
+            for data in eval_dataloader:
+                # TODO special forwarding logic here.
+                data = self.to_ctx(data)
+                data["output_hidden_states"] = True
+                hidden_video = model.forward_video(**data)
+                data["output_hidden_states"] = False
+                pooled_text = model.forward_text(**data)
+                outputs = {
+                    "hidden_video": hidden_video,
+                    "pooled_text": pooled_text
+                }
+                outputs.update(data)
+                self(outputs)
+        return self.finalize(output_file)
+
+    def __call__(self, sample):
+        # TODO: make an index select from self.possible_segments.
+        hidden_video = sample["hidden_video"]
+        pooled_text = sample["pooled_text"]
+        vmasks = sample["vmasks"]
+        # probably maintain valid results here.
+
+        hidden_video = hidden_video[:, 1:-1, :]
+        # probably maintain valid results here.
+        pooled_video = []
+        for s, e in self.possible_segments:
+            pooled_video.append(
+                torch.mean(
+                    hidden_video[:, int(s*5):int((e+1)*5), :],
+                    dim=1, keepdim=True)
+            )
+        pooled_video = torch.cat(pooled_video, dim=1)
+        scores = torch.bmm(
+            pooled_video, pooled_text.unsqueeze(-1)).squeeze(-1).cpu()
+
+        ranks = scores.argsort(dim=-1, descending=True)
+
+        for batch_idx, rank in enumerate(ranks):
+            rank_of_moment = []
+            for m_idx, moment in enumerate(rank):
+                s, e = self.possible_segments[moment.item()]
+                if torch.any(
+                    vmasks[batch_idx, int(s*5):int((e+1)*5)]
+                ):
+                    rank_of_moment.append((s, e))
+            self.full_scores.append(rank_of_moment)
+
+    def finalize(self, output_file=None):
+        outputs = self._aggregate_scores(self.full_scores)
+        if output_file is not None:
+            np.save(os.path.join(self.pred_dir, output_file + ".npy"), outputs)
+        return {"outputs": outputs, "targets": self.test_data}
+
+    def _aggregate_scores(self, scores):
+        self.full_scores = []
+        return scores

+ 16 - 0
examples/MMPT/mmpt/losses/__init__.py

@@ -0,0 +1,16 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from .loss import *
+from .nce import *
+
+try:
+    from .fairseqmmloss import *
+except ImportError:
+    pass
+
+try:
+    from .expnce import *
+except ImportError:
+    pass

+ 63 - 0
examples/MMPT/mmpt/losses/fairseqmmloss.py

@@ -0,0 +1,63 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+TODO (huxu): a general fairseq criterion for all your pre-defined losses.
+"""
+
+from fairseq.criterions import FairseqCriterion, register_criterion
+from fairseq.logging import metrics
+
+
+@register_criterion("mmloss")
+class MMCriterion(FairseqCriterion):
+    def __init__(self, task):
+        super().__init__(task)
+        # TODO (huxu): wrap forward call of loss_fn and eval_fn into task.
+        self.mmtask = task.mmtask
+
+    def forward(self, model, sample):
+        """Compute the loss for the given sample.
+        Returns a tuple with three elements:
+        1) the loss
+        2) the sample size, which is used as the denominator for the gradient
+        3) logging outputs to display while training
+        """
+        outputs = self.mmtask(model, sample)
+
+        loss, loss_scalar, max_len, batch_size, sample_size = (
+            outputs["loss"],
+            outputs["loss_scalar"],
+            outputs["max_len"],
+            outputs["batch_size"],
+            outputs["sample_size"],
+        )
+
+        logging_output = {
+            "loss": loss_scalar,
+            "ntokens": max_len * batch_size,  # dummy report.
+            "nsentences": batch_size,  # dummy report.
+            "sample_size": sample_size,
+        }
+
+        return loss, 1, logging_output
+
+    @staticmethod
+    def reduce_metrics(logging_outputs) -> None:
+        """Aggregate logging outputs from data parallel training."""
+        """since we use NCE, our actual batch_size is 1 per GPU.
+        Then we take the mean of each worker."""
+        loss_sum = sum(log.get("loss", 0.0) for log in logging_outputs)
+        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
+        metrics.log_scalar("loss", loss_sum / sample_size, round=3)
+
+    @staticmethod
+    def logging_outputs_can_be_summed() -> bool:
+        """
+        Whether the logging outputs returned by `forward` can be summed
+        across workers prior to calling `reduce_metrics`. Setting this
+        to True will improve distributed training speed.
+        """
+        return True

+ 87 - 0
examples/MMPT/mmpt/losses/loss.py

@@ -0,0 +1,87 @@
+# Copyright (c) Facebook, Inc. All Rights Reserved
+
+import torch
+
+from torch import nn
+
+
+class Loss(object):
+    def __call__(self, *args, **kwargs):
+        raise NotImplementedError
+
+
+# Dummy Loss for testing.
+class DummyLoss(Loss):
+    def __init__(self):
+        self.loss = nn.CrossEntropyLoss()
+
+    def __call__(self, logits, targets, **kwargs):
+        return self.loss(logits, targets)
+
+
+class DummyK400Loss(Loss):
+    """dummy k400 loss for MViT."""
+    def __init__(self):
+        self.loss = nn.CrossEntropyLoss()
+
+    def __call__(self, logits, targets, **kwargs):
+        return self.loss(
+            logits, torch.randint(0, 400, (logits.size(0),), device=logits.device))
+
+
+class CrossEntropy(Loss):
+    def __init__(self):
+        self.loss = nn.CrossEntropyLoss()
+
+    def __call__(self, logits, targets, **kwargs):
+        return self.loss(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
+
+
+class ArgmaxCrossEntropy(Loss):
+    def __init__(self):
+        self.loss = nn.CrossEntropyLoss()
+
+    def __call__(self, logits, targets, **kwargs):
+        return self.loss(logits, targets.argmax(dim=1))
+
+
+class BCE(Loss):
+    def __init__(self):
+        self.loss = nn.BCEWithLogitsLoss()
+
+    def __call__(self, logits, targets, **kwargs):
+        targets = targets.squeeze(0)
+        return self.loss(logits, targets)
+
+
+class NLGLoss(Loss):
+    def __init__(self):
+        self.loss = nn.CrossEntropyLoss()
+
+    def __call__(self, logits, text_label, **kwargs):
+        targets = text_label[text_label != -100]
+        return self.loss(logits, targets)
+
+
+class MSE(Loss):
+    def __init__(self):
+        self.loss = nn.MSELoss()
+
+    def __call__(self, logits, targets, **kwargs):
+        return self.loss(logits, targets)
+
+
+class L1(Loss):
+    def __init__(self):
+        self.loss = nn.L1Loss()
+
+    def __call__(self, logits, targets, **kwargs):
+        return self.loss(logits, targets)
+
+
+class SmoothL1(Loss):
+    def __init__(self):
+        self.loss = nn.SmoothL1Loss()
+
+    def __call__(self, logits, targets, **kwargs):
+        return self.loss(logits, targets)

+ 156 - 0
examples/MMPT/mmpt/losses/nce.py

@@ -0,0 +1,156 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+softmax-based NCE loss, used by this project.
+"""
+
+import torch
+
+from torch import nn
+
+from .loss import Loss
+
+
+class NCE(Loss):
+    def __init__(self):
+        # TODO (huxu): define temperature.
+        self.loss = nn.CrossEntropyLoss()
+
+    def __call__(self, align_scores, **kwargs):
+        # note: we reuse the same shape as cls head in BERT (batch_size, 2)
+        # but NCE only needs one logits.
+        # (so we drop all weights in the second neg logits.)
+        align_scores = align_scores[:, :1]
+        # duplicate negative examples
+        batch_size = align_scores.size(0) // 2
+        pos_scores = align_scores[:batch_size]
+        neg_scores = align_scores[batch_size:].view(1, batch_size).repeat(
+            batch_size, 1)
+        scores = torch.cat([pos_scores, neg_scores], dim=1)
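+        # the positive logit sits in column 0 and the batch of negatives
+        # follows, so the target class is always 0.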
+        return self.loss(
+            scores,
+            torch.zeros(
+                (batch_size,),
+                dtype=torch.long,
+                device=align_scores.device),
+        )
+
+
+class T2VContraLoss(Loss):
+    """NCE for MM joint space, on softmax text2video matrix.
+    """
+    def __init__(self):
+        # TODO (huxu): define temperature.
+        self.loss = nn.CrossEntropyLoss()
+
+    def __call__(self, pooled_video, pooled_text, **kwargs):
+        batch_size = pooled_video.size(0)
+        logits = torch.mm(pooled_text, pooled_video.transpose(1, 0))
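+        # logits[i, j] is the similarity of text i to video j; matching pairs
+        # lie on the diagonal, hence targets = arange(batch_size).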
+        targets = torch.arange(
+            batch_size,
+            dtype=torch.long,
+            device=pooled_video.device)
+        return self.loss(logits, targets)
+
+
+class V2TContraLoss(Loss):
+    """NCE for MM joint space, with softmax on video2text matrix."""
+
+    def __init__(self):
+        # TODO (huxu): define temperature.
+        self.loss = nn.CrossEntropyLoss()
+
+    def __call__(self, pooled_video, pooled_text, **kwargs):
+        batch_size = pooled_video.size(0)
+        logits = torch.mm(pooled_video, pooled_text.transpose(1, 0))
+        targets = torch.arange(
+            batch_size,
+            dtype=torch.long,
+            device=pooled_video.device)
+        return self.loss(logits, targets)
+
+
+class MMContraLoss(Loss):
+    def __init__(self):
+        self.loss = nn.CrossEntropyLoss()
+
+    def __call__(self, pooled_video, pooled_text, **kwargs):
+        logits_per_video = pooled_video @ pooled_text.t()
+        logits_per_text = pooled_text @ pooled_video.t()
+
+        targets = torch.arange(
+            pooled_video.size(0),
+            dtype=torch.long,
+            device=pooled_video.device)
+        loss_video = self.loss(logits_per_video, targets)
+        loss_text = self.loss(logits_per_text, targets)
+        return loss_video + loss_text
+
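+# A quick sanity-check sketch for the contrastive losses above (hypothetical
+# shapes; during training the pooled embeddings come from the MMFusion models):
+#
+#     pooled_video = torch.randn(8, 768)
+#     pooled_text = torch.randn(8, 768)
+#     loss = MMContraLoss()(pooled_video=pooled_video, pooled_text=pooled_text)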
+
+class MTM(Loss):
+    """Combination of MFM and MLM."""
+
+    def __init__(self):
+        self.loss = nn.CrossEntropyLoss()
+
+    def __call__(
+        self,
+        video_logits,
+        text_logits,
+        video_label,
+        text_label,
+        **kwargs
+    ):
+        text_logits = torch.cat([
+            text_logits,
+            torch.zeros(
+                (text_logits.size(0), 1), device=text_logits.device)
+        ], dim=1)
+        vt_logits = torch.cat([video_logits, text_logits], dim=0)
+        # loss for video.
+        video_label = torch.zeros(
+            (video_logits.size(0),),
+            dtype=torch.long,
+            device=video_logits.device
+        )
+
+        # loss for text.
+        text_label = text_label.reshape(-1)
+        labels_mask = text_label != -100
+        selected_text_label = text_label[labels_mask]
+
+        vt_label = torch.cat([video_label, selected_text_label], dim=0)
+        return self.loss(vt_logits, vt_label)
+
+
+class MFMMLM(Loss):
+    """Combination of MFM and MLM."""
+
+    def __init__(self):
+        self.loss = nn.CrossEntropyLoss()
+
+    def __call__(
+        self,
+        video_logits,
+        text_logits,
+        video_label,
+        text_label,
+        **kwargs
+    ):
+        # loss for video.
+        video_label = torch.zeros(
+            (video_logits.size(0),),
+            dtype=torch.long,
+            device=video_logits.device
+        )
+        masked_frame_loss = self.loss(video_logits, video_label)
+
+        # loss for text.
+        text_label = text_label.reshape(-1)
+        labels_mask = text_label != -100
+        selected_text_label = text_label[labels_mask]
+        masked_lm_loss = self.loss(text_logits, selected_text_label)
+        return masked_frame_loss + masked_lm_loss

+ 17 - 0
examples/MMPT/mmpt/models/__init__.py

@@ -0,0 +1,17 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from .mmfusion import *
+from .transformermodel import *
+from .mmfusionnlg import *
+
+try:
+    from .fairseqmmmodel import *
+except ImportError:
+    pass
+
+try:
+    from .expmmfusion import *
+except ImportError:
+    pass

+ 51 - 0
examples/MMPT/mmpt/models/fairseqmmmodel.py

@@ -0,0 +1,51 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from fairseq.models import (
+    BaseFairseqModel,
+    register_model,
+    register_model_architecture
+)
+
+
+@register_model("mmmodel")
+class FairseqMMModel(BaseFairseqModel):
+    """a fairseq wrapper of model built by `task`."""
+
+    @classmethod
+    def build_model(cls, args, task):
+        return FairseqMMModel(task.mmtask.model)
+
+    def __init__(self, mmmodel):
+        super().__init__()
+        self.mmmodel = mmmodel
+
+    def forward(self, *args, **kwargs):
+        return self.mmmodel(*args, **kwargs)
+
+    def upgrade_state_dict_named(self, state_dict, name):
+
+        super().upgrade_state_dict_named(state_dict, name)
+
+        keys_to_delete = []
+
+        for key in state_dict:
+            if key not in self.state_dict():
+                keys_to_delete.append(key)
+        for key in keys_to_delete:
+            print("[INFO]", key, "not used anymore.")
+            del state_dict[key]
+
+        # copy any newly defined parameters.
+        for key in self.state_dict():
+            if key not in state_dict:
+                print("[INFO] adding", key)
+                state_dict[key] = self.state_dict()[key]
+
+
+# a dummy arch, we config the model.
+@register_model_architecture("mmmodel", "mmarch")
+def mmarch(args):
+    pass

+ 926 - 0
examples/MMPT/mmpt/models/mmfusion.py

@@ -0,0 +1,926 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Copyright (c) Facebook, Inc. All Rights Reserved
+
+
+import torch
+
+from torch import nn
+
+try:
+    from transformers import AutoConfig, AutoTokenizer
+except ImportError:
+    pass
+
+from . import transformermodel
+
+
+class MMPTModel(nn.Module):
+    """An e2e wrapper of inference model.
+    """
+    @classmethod
+    def from_pretrained(cls, config, checkpoint="checkpoint_best.pt"):
+        import os
+        from ..utils import recursive_config
+        from ..tasks import Task
+        config = recursive_config(config)
+        mmtask = Task.config_task(config)
+        checkpoint_path = os.path.join(config.eval.save_path, checkpoint)
+        mmtask.build_model(checkpoint=checkpoint_path)
+        # TODO(huxu): make the video encoder configurable.
+        from ..processors.models.s3dg import S3D
+        video_encoder = S3D('pretrained_models/s3d_dict.npy', 512)
+        video_encoder.load_state_dict(
+            torch.load('pretrained_models/s3d_howto100m.pth'))
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(
+            config.dataset.bert_name, use_fast=config.dataset.use_fast
+        )
+        from ..processors import Aligner
+        aligner = Aligner(config.dataset)
+        return (
+            MMPTModel(config, mmtask.model, video_encoder),
+            tokenizer,
+            aligner
+        )
+
+    def __init__(self, config, model, video_encoder, **kwargs):
+        super().__init__()
+        self.max_video_len = config.dataset.max_video_len
+        self.video_encoder = video_encoder
+        self.model = model
+
+    def forward(self, video_frames, caps, cmasks, return_score=False):
+        bsz = video_frames.size(0)
+        assert bsz == 1, "only bsz=1 is supported now."
+        seq_len = video_frames.size(1)
+        video_frames = video_frames.view(-1, *video_frames.size()[2:])
+        vfeats = self.video_encoder(video_frames.permute(0, 4, 1, 2, 3))
+        vfeats = vfeats['video_embedding']
+        vfeats = vfeats.view(bsz, seq_len, vfeats.size(-1))
+        padding = torch.zeros(
+            bsz, self.max_video_len - seq_len, vfeats.size(-1))
+        vfeats = torch.cat([vfeats, padding], dim=1)
+        vmasks = torch.cat([
+            torch.ones((bsz, seq_len), dtype=torch.bool),
+            torch.zeros((bsz, self.max_video_len - seq_len), dtype=torch.bool)
+            ],
+            dim=1
+        )
+        output = self.model(caps, cmasks, vfeats, vmasks)
+        if return_score:
+            output = {"score": torch.bmm(
+                output["pooled_video"][:, None, :],
+                output["pooled_text"][:, :, None]
+            ).squeeze(-1).squeeze(-1)}
+        return output
+
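+# A minimal inference sketch for MMPTModel (the yaml path, pretrained files and
+# tensor shapes follow the MMPT README-style demo and should be treated as
+# assumptions, not fixed values):
+#
+#     model, tokenizer, aligner = MMPTModel.from_pretrained(
+#         "projects/retri/videoclip/how2.yaml")
+#     model.eval()
+#     video_frames = torch.randn(1, 2, 30, 224, 224, 3)  # (bsz, secs, fps, H, W, C)
+#     caps, cmasks = aligner._build_text_seq(
+#         tokenizer("a person is cooking", add_special_tokens=False)["input_ids"])
+#     with torch.no_grad():
+#         output = model(video_frames, caps[None, :], cmasks[None, :], return_score=True)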
+
+class MMFusion(nn.Module):
+    """a MMPT wrapper class for MMBert style models.
+    TODO: move isolated mask to a subclass.
+    """
+    def __init__(self, config, **kwargs):
+        super().__init__()
+        transformer_config = AutoConfig.from_pretrained(
+            config.dataset.bert_name)
+        self.hidden_size = transformer_config.hidden_size
+        self.is_train = False
+        if config.dataset.train_path is not None:
+            self.is_train = True
+        # 0 means no iso; 1-12 means iso up to that layer.
+        self.num_hidden_layers = transformer_config.num_hidden_layers
+        self.last_iso_layer = 0
+        if config.dataset.num_iso_layer is not None:
+            self.last_iso_layer = config.dataset.num_iso_layer - 1 + 1
+
+        if config.model.mm_encoder_cls is not None:
+            mm_encoder_cls = getattr(transformermodel, config.model.mm_encoder_cls)
+            model_config = AutoConfig.from_pretrained(config.dataset.bert_name)
+            model_config.max_video_len = config.dataset.max_video_len
+            # TODO: a general way to add parameter for a model.
+            model_config.use_seg_emb = config.model.use_seg_emb
+            self.mm_encoder = mm_encoder_cls.from_pretrained(
+                config.dataset.bert_name, config=model_config)
+        elif config.model.video_encoder_cls is not None\
+                and config.model.text_encoder_cls is not None:
+            video_encoder_cls = getattr(transformermodel, config.model.video_encoder_cls)
+            model_config = AutoConfig.from_pretrained(config.dataset.bert_name)
+            model_config.max_video_len = config.dataset.max_video_len
+            # TODO: make each model a set of config class.
+            if hasattr(model_config, "num_layers"):
+                model_config.num_layers = config.model.num_hidden_video_layers
+            else:
+                model_config.num_hidden_layers = config.model.num_hidden_video_layers
+            self.video_encoder = video_encoder_cls.from_pretrained(
+                config.dataset.bert_name, config=model_config)
+            # exact same NLP model from Huggingface.
+            text_encoder_cls = getattr(transformermodel, config.model.text_encoder_cls)
+            self.text_encoder = text_encoder_cls.from_pretrained(
+                config.dataset.bert_name)
+        else:
+            raise ValueError("the encoder must be either MM or two backbones.")
+
+    def forward(
+        self,
+        caps,
+        cmasks,
+        vfeats,
+        vmasks,
+        **kwargs
+    ):
+        raise NotImplementedError(
+            "Please derive MMFusion module."
+        )
+
+    def _mm_on_the_fly(
+        self,
+        cmasks,
+        vmasks,
+        attention_mask
+    ):
+        """helper function for mask, seg_ids and token_type_ids."""
+        if attention_mask is None:
+            attention_mask = self._mm_attention_mask(cmasks, vmasks)
+
+        """
+        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
+        | first sequence    | second sequence |
+        """
+        token_type_ids = torch.cat(
+            [
+                torch.zeros(
+                    (vmasks.size(0), vmasks.size(1) + 2),
+                    dtype=torch.long,
+                    device=vmasks.device,
+                ),
+                torch.ones(
+                    (cmasks.size(0), cmasks.size(1) - 2),
+                    dtype=torch.long,
+                    device=cmasks.device,
+                ),
+            ],
+            dim=1,
+        )
+        return attention_mask, token_type_ids
+
+    def _mm_attention_mask(self, cmasks, vmasks):
+        assert cmasks.size(0) == vmasks.size(0), "{}, {}, {}, {}".format(
+            str(cmasks.size()),
+            str(vmasks.size()),
+            str(cmasks.size(0)),
+            str(vmasks.size(0)),
+        )
+
+        mm_mask = torch.cat([cmasks[:, :1], vmasks, cmasks[:, 1:]], dim=1)
+        if self.last_iso_layer == 0:
+            # hard attention mask.
+            return mm_mask
+        else:
+            # a gpu iso mask; 0 : num_iso_layer is isolated;
+            # num_iso_layer: are MM-fused.
+            # make an iso layer
+            batch_size = cmasks.size(0)
+            iso_mask = self._make_iso_mask(batch_size, cmasks, vmasks)
+            mm_mask = mm_mask[:, None, :].repeat(1, mm_mask.size(-1), 1)
+            iso_mm_masks = []
+            # hard attention mask.
+            iso_mask = iso_mask[:, None, :, :].repeat(
+                1, self.last_iso_layer, 1, 1)
+            iso_mm_masks.append(iso_mask)
+            if self.last_iso_layer < self.num_hidden_layers:
+                mm_mask = mm_mask[:, None, :, :].repeat(
+                    1, self.num_hidden_layers - self.last_iso_layer, 1, 1
+                )
+                iso_mm_masks.append(mm_mask)
+            iso_mm_masks = torch.cat(iso_mm_masks, dim=1)
+            return iso_mm_masks
+
+    def _make_iso_mask(self, batch_size, cmasks, vmasks):
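+        # isolated attention: [CLS] attends only to itself, video tokens attend
+        # to video tokens (plus one [SEP]), and text tokens attend to text
+        # tokens; this mask is applied for the first `last_iso_layer` layers.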
+        cls_self_mask = torch.cat(
+            [
+                torch.ones(
+                    (batch_size, 1), dtype=torch.bool, device=cmasks.device),
+                torch.zeros(
+                    (batch_size, cmasks.size(1) + vmasks.size(1) - 1),
+                    dtype=torch.bool, device=cmasks.device)
+            ], dim=1)
+
+        iso_video_mask = torch.cat(
+            [
+                # [CLS] is not used.
+                torch.zeros(
+                    (batch_size, 1), dtype=torch.bool, device=cmasks.device
+                ),
+                vmasks,
+                # assume to be 1.
+                cmasks[:, 1:2],
+                # 2 means [CLS] + [SEP]
+                torch.zeros(
+                    (batch_size, cmasks.size(1) - 2),
+                    dtype=torch.bool,
+                    device=cmasks.device,
+                ),
+            ],
+            dim=1,
+        )
+        iso_text_mask = torch.cat(
+            [
+                torch.zeros(
+                    (batch_size, 2 + vmasks.size(1)),
+                    dtype=torch.bool,
+                    device=cmasks.device,
+                ),  # [CLS] is not used.
+                cmasks[:, 2:],  # assume to be 1.
+            ],
+            dim=1,
+        )
+        cls_self_mask = cls_self_mask[:, None, :]
+        iso_video_mask = iso_video_mask[:, None, :].repeat(
+            1, vmasks.size(1) + 1, 1)
+        iso_text_mask = iso_text_mask[:, None, :].repeat(
+            1, cmasks.size(1) - 2, 1)
+        return torch.cat([cls_self_mask, iso_video_mask, iso_text_mask], dim=1)
+
+    def _pooling_vt_layer(
+        self,
+        layered_sequence_output,
+        cmasks,
+        vmasks
+    ):
+        layer_idx = self.last_iso_layer \
+                if self.last_iso_layer > 0 else self.num_hidden_layers
+        hidden_state = layered_sequence_output[layer_idx]
+        # also output pooled_video and pooled_text.
+        batch_size = cmasks.size(0)
+        # pool the modality.
+        text_offset = vmasks.size(1) + 2  # [CLS] + [SEP]
+        # video tokens + [SEP]
+        video_outputs = hidden_state[:, 1:text_offset]
+        video_attention_mask = torch.cat(
+            [
+                vmasks,
+                torch.ones(
+                    (batch_size, 1), dtype=torch.bool, device=vmasks.device),
+            ],
+            dim=1,
+        )
+        assert video_outputs.size(1) == video_attention_mask.size(1)
+        pooled_video = torch.sum(
+            video_outputs * video_attention_mask.unsqueeze(-1), dim=1
+        ) / video_attention_mask.sum(1, keepdim=True)
+        # pooled_video = torch.mean(video_outputs[0], dim=1)
+
+        # text tokens + [SEP]
+        text_attention_mask = cmasks[:, 2:]
+        text_outputs = hidden_state[:, text_offset:]
+        assert text_outputs.size(1) == text_attention_mask.size(1)
+        pooled_text = torch.sum(
+            text_outputs * text_attention_mask.unsqueeze(-1), dim=1
+        ) / text_attention_mask.sum(1, keepdim=True)
+        return pooled_video, pooled_text
+
+
+class MMFusionMFMMLM(MMFusion):
+    """forward function for MFM and MLM."""
+    def forward(
+        self,
+        caps,
+        cmasks,
+        vfeats,
+        vmasks,
+        attention_mask=None,
+        video_label=None,
+        text_label=None,
+        **kwargs
+    ):
+        output_hidden_states = not self.is_train
+
+        target_vfeats, non_masked_frame_mask = None, None
+        if video_label is not None:
+            target_vfeats = vfeats.masked_select(
+                video_label.unsqueeze(-1)).view(
+                -1, vfeats.size(-1)
+            )
+            # mask video token.
+            vfeats[video_label] = 0.0
+            non_masked_frame_mask = vmasks.clone()
+            non_masked_frame_mask[video_label] = False
+
+        attention_mask, token_type_ids = self._mm_on_the_fly(
+            cmasks, vmasks, attention_mask)
+
+        outputs = self.mm_encoder(
+            input_ids=caps,
+            input_video_embeds=vfeats,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            masked_frame_labels=video_label,
+            target_video_hidden_states=target_vfeats,
+            non_masked_frame_mask=non_masked_frame_mask,
+            masked_lm_labels=text_label,
+            output_hidden_states=output_hidden_states,
+        )
+
+        video_logits, text_logits = outputs[0], outputs[1]
+
+        if self.is_train:  # return earlier for training.
+            return {
+                "video_logits": video_logits,
+                "text_logits": text_logits,
+            }
+
+        pooled_video, pooled_text = self._pooling_vt_layer(
+            outputs[2], cmasks, vmasks)
+        return {"pooled_video": pooled_video, "pooled_text": pooled_text}
+
+
+class MMFusionMTM(MMFusionMFMMLM):
+    def __init__(self, config, **kwargs):
+        super().__init__(config)
+        """
+        For reproducibility:
+        self.mm_encoder will be initialized then discarded.
+        """
+        from .transformermodel import MMBertForMTM
+        model_config = AutoConfig.from_pretrained(config.dataset.bert_name)
+        model_config.max_video_len = config.dataset.max_video_len
+        model_config.use_seg_emb = config.model.use_seg_emb
+        self.mm_encoder = MMBertForMTM.from_pretrained(
+            config.dataset.bert_name, config=model_config)
+
+
+class MMFusionShare(MMFusion):
+    """A retrival wrapper using mm_encoder as both video/text backbone.
+    TODO: move formally.
+    """
+    def forward(
+        self,
+        caps,
+        cmasks,
+        vfeats,
+        vmasks,
+        attention_mask=None,
+        video_label=None,
+        text_label=None,
+        output_hidden_states=False,
+        **kwargs
+    ):
+        pooled_video = self.forward_video(
+            vfeats,
+            vmasks,
+            caps,
+            cmasks,
+            output_hidden_states
+        )
+
+        pooled_text = self.forward_text(
+            caps,
+            cmasks,
+            output_hidden_states
+        )
+
+        return {"pooled_video": pooled_video, "pooled_text": pooled_text}
+
+    def forward_video(
+        self,
+        vfeats,
+        vmasks,
+        caps,
+        cmasks,
+        output_hidden_states=False,
+        **kwargs
+    ):
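+        # only [CLS] and the first [SEP] are used as text input here; the
+        # content comes entirely from the video features.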
+        input_ids = caps[:, :2]
+
+        attention_mask = torch.cat([
+            cmasks[:, :1],
+            vmasks,
+            cmasks[:, 1:2]
+        ], dim=1)
+
+        token_type_ids = torch.zeros(
+            (vmasks.size(0), vmasks.size(1) + 2),
+            dtype=torch.long,
+            device=vmasks.device)
+
+        outputs = self.mm_encoder(
+            input_ids=input_ids,
+            input_video_embeds=vfeats,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            output_hidden_states=True
+        )
+        video_outputs = outputs[0]
+
+        if output_hidden_states:
+            return video_outputs
+
+        batch_size = cmasks.size(0)
+
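+        # pooling mask: exclude the leading [CLS] slot, keep the video frames
+        # and the trailing [SEP] slot.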
+        video_attention_mask = torch.cat(
+            [
+                torch.zeros(
+                    (batch_size, 1), dtype=torch.bool, device=vmasks.device),
+                vmasks,
+                torch.ones(
+                    (batch_size, 1), dtype=torch.bool, device=vmasks.device),
+            ],
+            dim=1,
+        )
+        assert video_outputs.size(1) == video_attention_mask.size(1)
+
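+        # normalize the mask so the batched matmul below computes a masked mean.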
+        video_attention_mask = video_attention_mask.type(video_outputs.dtype) \
+            / video_attention_mask.sum(1, keepdim=True)
+
+        pooled_video = torch.bmm(
+            video_outputs.transpose(2, 1),
+            video_attention_mask.unsqueeze(2)
+        ).squeeze(-1)
+        return pooled_video  # video_outputs
+
+    def forward_text(
+        self,
+        caps,
+        cmasks,
+        output_hidden_states=False,
+        **kwargs
+    ):
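+        # drop the [SEP] at position 1 (it closes the video segment in the
+        # joint format), leaving [CLS] + caption tokens + final [SEP].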
+        input_ids = torch.cat([
+            caps[:, :1], caps[:, 2:],
+            ], dim=1)
+
+        attention_mask = torch.cat([
+            cmasks[:, :1],
+            cmasks[:, 2:]
+        ], dim=1)
+
+        token_type_ids = torch.cat([
+            torch.zeros(
+                (cmasks.size(0), 1),
+                dtype=torch.long,
+                device=cmasks.device),
+            torch.ones(
+                (cmasks.size(0), cmasks.size(1) - 2),
+                dtype=torch.long,
+                device=cmasks.device)
+            ], dim=1)
+
+        outputs = self.mm_encoder(
+            input_ids=input_ids,
+            input_video_embeds=None,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            output_hidden_states=True
+        )
+        text_outputs = outputs[0]
+
+        if output_hidden_states:
+            return text_outputs
+
+        batch_size = caps.size(0)
+        # text tokens + [SEP]
+        text_attention_mask = torch.cat([
+            torch.zeros(
+                (batch_size, 1), dtype=torch.bool, device=cmasks.device),
+            cmasks[:, 2:]
+        ], dim=1)
+
+        assert text_outputs.size(1) == text_attention_mask.size(1)
+
+        text_attention_mask = text_attention_mask.type(text_outputs.dtype) \
+            / text_attention_mask.sum(1, keepdim=True)
+
+        pooled_text = torch.bmm(
+            text_outputs.transpose(2, 1),
+            text_attention_mask.unsqueeze(2)
+        ).squeeze(-1)
+        return pooled_text  # text_outputs
+
+
+class MMFusionSeparate(MMFusionShare):
+    def forward_video(
+        self,
+        vfeats,
+        vmasks,
+        caps,
+        cmasks,
+        output_hidden_states=False,
+        **kwargs
+    ):
+        input_ids = caps[:, :2]
+
+        attention_mask = torch.cat([
+            cmasks[:, :1],
+            vmasks,
+            cmasks[:, 1:2]
+        ], dim=1)
+
+        token_type_ids = torch.zeros(
+            (vmasks.size(0), vmasks.size(1) + 2),
+            dtype=torch.long,
+            device=vmasks.device)
+
+        outputs = self.video_encoder(
+            input_ids=input_ids,
+            input_video_embeds=vfeats,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            output_hidden_states=True
+        )
+        video_outputs = outputs[0]
+
+        if output_hidden_states:
+            return video_outputs
+
+        batch_size = cmasks.size(0)
+
+        video_attention_mask = torch.cat(
+            [
+                torch.zeros(
+                    (batch_size, 1), dtype=torch.bool, device=vmasks.device),
+                vmasks,
+                torch.ones(
+                    (batch_size, 1), dtype=torch.bool, device=vmasks.device),
+            ],
+            dim=1,
+        )
+        assert video_outputs.size(1) == video_attention_mask.size(1)
+
+        video_attention_mask = video_attention_mask.type(video_outputs.dtype) \
+            / video_attention_mask.sum(1, keepdim=True)
+
+        pooled_video = torch.bmm(
+            video_outputs.transpose(2, 1),
+            video_attention_mask.unsqueeze(2)
+        ).squeeze(-1)
+        return pooled_video  # video_outputs
+
+    def forward_text(
+        self,
+        caps,
+        cmasks,
+        output_hidden_states=False,
+        **kwargs
+    ):
+        input_ids = torch.cat([
+            caps[:, :1], caps[:, 2:],
+            ], dim=1)
+
+        attention_mask = torch.cat([
+            cmasks[:, :1],
+            cmasks[:, 2:]
+        ], dim=1)
+        # unlike the shared encoder, use all-zero token type ids for the text.
+        token_type_ids = torch.zeros(
+            (cmasks.size(0), cmasks.size(1) - 1),
+            dtype=torch.long,
+            device=cmasks.device)
+
+        outputs = self.text_encoder(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            output_hidden_states=True
+        )
+        text_outputs = outputs[0]
+
+        if output_hidden_states:
+            return text_outputs
+
+        batch_size = caps.size(0)
+        # text tokens + [SEP]
+        text_attention_mask = torch.cat([
+            torch.zeros(
+                (batch_size, 1), dtype=torch.bool, device=cmasks.device),
+            cmasks[:, 2:]
+        ], dim=1)
+
+        assert text_outputs.size(1) == text_attention_mask.size(1)
+
+        text_attention_mask = text_attention_mask.type(text_outputs.dtype) \
+            / text_attention_mask.sum(1, keepdim=True)
+
+        pooled_text = torch.bmm(
+            text_outputs.transpose(2, 1),
+            text_attention_mask.unsqueeze(2)
+        ).squeeze(-1)
+        return pooled_text  # text_outputs
+
+
+class MMFusionJoint(MMFusion):
+    """fine-tuning wrapper for retrival task."""
+
+    def forward(
+        self,
+        caps,
+        cmasks,
+        vfeats,
+        vmasks,
+        attention_mask=None,
+        video_label=None,
+        text_label=None,
+        **kwargs
+    ):
+        # TODO (huxu): other ways to do negative examples; move the following
+        # into your criterion forward.
+        output_hidden_states = True
+
+        attention_mask, token_type_ids = self._mm_on_the_fly(
+            cmasks, vmasks, attention_mask)
+
+        separate_forward_split = (
+            None if self.is_train else vmasks.size(1) + 2
+        )  # [CLS] + [SEP]
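+        # the split index is [CLS] + video frames + [SEP]; it is only passed
+        # at inference time (None during training).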
+
+        outputs = self.mm_encoder(
+            input_ids=caps,
+            input_video_embeds=vfeats,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            output_hidden_states=output_hidden_states,
+            separate_forward_split=separate_forward_split,
+        )
+
+        pooled_video, pooled_text = self._pooling_vt_layer(
+            outputs[2], cmasks, vmasks)
+        return {"pooled_video": pooled_video, "pooled_text": pooled_text}
+
+
+class MMFusionActionSegmentation(MMFusion):
+    """Fine-tuning wrapper for action segmentation.
+    TODO: rename this for VLM.
+    """
+    def forward(
+        self,
+        caps,
+        cmasks,
+        vfeats,
+        vmasks,
+        attention_mask=None,
+        **kwargs
+    ):
+        # The action-segmentation loader assumes batch_size=1; flatten the extra leading dimension.
+        caps = caps.view(-1, caps.size(-1))
+        cmasks = cmasks.view(-1, cmasks.size(-1))
+        vfeats = vfeats.view(-1, vfeats.size(2), vfeats.size(3))
+        vmasks = vmasks.view(-1, vmasks.size(-1))
+
+        # this may not cover all shapes of attention_mask.
+        attention_mask = attention_mask.view(
+            -1, attention_mask.size(2), attention_mask.size(3)) \
+            if attention_mask is not None else None
+
+        # TODO (huxu): other ways to do negative examples; move the following
+        # into your criterion forward.
+        output_hidden_states = True
+
+        #  video forwarding, text is dummy; never use attention_mask.
+        attention_mask, token_type_ids = self._mm_on_the_fly(
+            cmasks, vmasks, attention_mask)
+
+        logits = self.mm_encoder(
+            input_ids=caps,
+            input_video_embeds=vfeats,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            output_hidden_states=output_hidden_states,
+        )
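+        # keep only the per-frame logits (positions 1..num_frames, skipping the leading [CLS]).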
+        return {"logits": logits[0][:, 1:vmasks.size(1)+1]}
+
+
+class MMFusionActionLocalization(MMFusion):
+    """fine-tuning model for retrival task."""
+
+    def __init__(self, config, **kwargs):
+        super().__init__(config)
+        tokenizer = AutoTokenizer.from_pretrained(
+            config.dataset.bert_name)
+        self.cls_token_id = tokenizer.cls_token_id
+        self.sep_token_id = tokenizer.sep_token_id
+        self.pad_token_id = tokenizer.pad_token_id
+
+    def forward(
+        self,
+        caps,
+        cmasks,
+        vfeats,
+        vmasks,
+        attention_mask=None,
+        **kwargs
+    ):
+        # ActionLocalization assumes batch_size=1; squeeze it out.
+        caps = caps.squeeze(0)
+        cmasks = cmasks.squeeze(0)
+        vfeats = vfeats.squeeze(0)
+        vmasks = vmasks.squeeze(0)
+        attention_mask = attention_mask.squeeze(0) if attention_mask is not None else None
+
+        # TODO (huxu): other ways to do negative examples; move the following
+        # into your criterion forward.
+        output_hidden_states = True
+
+        # a length-1 dummy video token.
+        dummy_vfeats = torch.zeros(
+            (caps.size(0), 1, vfeats.size(-1)), device=vfeats.device, dtype=vfeats.dtype)
+        dummy_vmasks = torch.ones(
+            (caps.size(0), 1), dtype=torch.bool,
+            device=vfeats.device)
+
+        dummy_caps = torch.LongTensor(
+            [[self.cls_token_id, self.sep_token_id,
+              self.pad_token_id, self.sep_token_id]],
+            ).to(caps.device).repeat(vfeats.size(0), 1)
+        dummy_cmasks = torch.BoolTensor(
+            [[0, 1, 0, 1]]  # pad are valid for attention.
+            ).to(caps.device).repeat(vfeats.size(0), 1)
+
+        #  video forwarding, text is dummy; never use attention_mask.
+        attention_mask, token_type_ids = self._mm_on_the_fly(
+            dummy_cmasks, vmasks, None)
+
+        outputs = self.mm_encoder(
+            input_ids=dummy_caps,
+            input_video_embeds=vfeats,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            output_hidden_states=output_hidden_states,
+        )
+
+        layer_idx = self.last_iso_layer \
+                if self.last_iso_layer > 0 else self.num_hidden_layers
+
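+        # flatten the hidden states of the valid video frames across the batch
+        # into a (num_valid_frames, hidden_size) matrix.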
+        video_seq = outputs[2][layer_idx][:, 1:vmasks.size(1)+1].masked_select(
+                vmasks.unsqueeze(-1)
+            ).view(-1, self.hidden_size)
+
+        # text forwarding, video is dummy
+        attention_mask, token_type_ids = self._mm_on_the_fly(
+            cmasks, dummy_vmasks, None)
+
+        outputs = self.mm_encoder(
+            input_ids=caps,
+            input_video_embeds=dummy_vfeats,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            output_hidden_states=output_hidden_states,
+        )
+
+        _, pooled_text = self._pooling_vt_layer(
+            outputs[2], cmasks, dummy_vmasks)
+        # TODO: this logits computation is not right yet.
+        logits = torch.mm(video_seq, pooled_text.transpose(1, 0))
+        return {"logits": logits}
+
+
+# --------------- MMFusionSeparate for end tasks ---------------
+
+class MMFusionSeparateActionSegmentation(MMFusionSeparate):
+    """Fine-tuning wrapper for action segmentation."""
+    def forward(
+        self,
+        caps,
+        cmasks,
+        vfeats,
+        vmasks,
+        attention_mask=None,
+        **kwargs
+    ):
+        # The action-segmentation loader assumes batch_size=1; flatten the extra leading dimension.
+        caps = caps.view(-1, caps.size(-1))
+        cmasks = cmasks.view(-1, cmasks.size(-1))
+        vfeats = vfeats.view(-1, vfeats.size(2), vfeats.size(3))
+        vmasks = vmasks.view(-1, vmasks.size(-1))
+        logits = self.forward_video(
+            vfeats,
+            vmasks,
+            caps,
+            cmasks,
+            output_hidden_states=True
+        )
+        return {"logits": logits[:, 1:vmasks.size(1)+1]}
+
+
+class MMFusionSeparateActionLocalization(MMFusionSeparate):
+    def __init__(self, config, **kwargs):
+        super().__init__(config)
+        tokenizer = AutoTokenizer.from_pretrained(
+            config.dataset.bert_name)
+        self.cls_token_id = tokenizer.cls_token_id
+        self.sep_token_id = tokenizer.sep_token_id
+        self.pad_token_id = tokenizer.pad_token_id
+
+    def forward(
+        self,
+        caps,
+        cmasks,
+        vfeats,
+        vmasks,
+        **kwargs
+    ):
+        # ActionLocalization assumes batch_size=1; squeeze it out.
+        caps = caps.squeeze(0)
+        cmasks = cmasks.squeeze(0)
+        vfeats = vfeats.squeeze(0)
+        vmasks = vmasks.squeeze(0)
+
+        # TODO (huxu): other ways to do negative examples; move the following
+        # into your criterion forward.
+        dummy_caps = torch.LongTensor(
+            [[self.cls_token_id, self.sep_token_id,
+              self.pad_token_id, self.sep_token_id]],
+            ).to(caps.device).repeat(vfeats.size(0), 1)
+        dummy_cmasks = torch.BoolTensor(
+            [[0, 1, 0, 1]]  # pad are valid for attention.
+            ).to(caps.device).repeat(vfeats.size(0), 1)
+
+        outputs = self.forward_video(
+            vfeats,
+            vmasks,
+            dummy_caps,
+            dummy_cmasks,
+            output_hidden_states=True
+        )
+
+        video_seq = outputs[:, 1:vmasks.size(1)+1].masked_select(
+                vmasks.unsqueeze(-1)
+            ).view(-1, self.hidden_size)
+
+        pooled_text = self.forward_text(
+            caps,
+            cmasks,
+            output_hidden_states=False
+        )
+
+        # TODO: this logits computation is not right yet.
+        logits = torch.mm(video_seq, pooled_text.transpose(1, 0))
+        return {"logits": logits}
+
+
+class MMFusionShareActionLocalization(MMFusionShare):
+    def __init__(self, config, **kwargs):
+        super().__init__(config)
+        tokenizer = AutoTokenizer.from_pretrained(
+            config.dataset.bert_name)
+        self.cls_token_id = tokenizer.cls_token_id
+        self.sep_token_id = tokenizer.sep_token_id
+        self.pad_token_id = tokenizer.pad_token_id
+
+    def forward(
+        self,
+        caps,
+        cmasks,
+        vfeats,
+        vmasks,
+        **kwargs
+    ):
+        # ActionLocalization assumes batch_size=1; squeeze it out.
+        caps = caps.squeeze(0)
+        cmasks = cmasks.squeeze(0)
+        vfeats = vfeats.squeeze(0)
+        vmasks = vmasks.squeeze(0)
+
+        # TODO (huxu): other ways to do negative examples; move the following
+        # into your criterion forward.
+        dummy_caps = torch.LongTensor(
+            [[self.cls_token_id, self.sep_token_id,
+              self.pad_token_id, self.sep_token_id]],
+            ).to(caps.device).repeat(vfeats.size(0), 1)
+        dummy_cmasks = torch.BoolTensor(
+            [[0, 1, 0, 1]]  # pad are valid for attention.
+            ).to(caps.device).repeat(vfeats.size(0), 1)
+
+        outputs = self.forward_video(
+            vfeats,
+            vmasks,
+            dummy_caps,
+            dummy_cmasks,
+            output_hidden_states=True
+        )
+
+        video_seq = outputs[:, 1:vmasks.size(1)+1].masked_select(
+                vmasks.unsqueeze(-1)
+            ).view(-1, self.hidden_size)
+
+        pooled_text = self.forward_text(
+            caps,
+            cmasks,
+            output_hidden_states=False
+        )
+
+        # TODO: this logits computation is not right yet.
+        logits = torch.mm(video_seq, pooled_text.transpose(1, 0))
+        return {"logits": logits}

+ 999 - 0
examples/MMPT/mmpt/models/mmfusionnlg.py

@@ -0,0 +1,999 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Copyright (c) Facebook, Inc. All Rights Reserved
+
+
+import torch
+
+from torch.nn import functional as F
+
+from typing import Optional, Iterable
+
+try:
+    from transformers import BertPreTrainedModel
+    from transformers.modeling_bert import BertOnlyMLMHead
+
+    from transformers.file_utils import ModelOutput
+    from transformers.modeling_outputs import CausalLMOutput
+    from transformers.generation_utils import (
+        BeamHypotheses,
+        top_k_top_p_filtering
+    )
+except ImportError:
+    pass
+
+from .mmfusion import MMFusion
+from .transformermodel import MMBertModel
+from ..modules import VideoTokenMLP
+
+
+class MMFusionNLG(MMFusion):
+    def __init__(self, config, **kwargs):
+        super().__init__(config)
+        if config.model.max_decode_length is not None:
+            self.max_length = min(
+                config.model.max_decode_length,
+                config.dataset.max_len - config.dataset.max_video_len - 3
+            )
+        else:
+            self.max_length = \
+                config.dataset.max_len - config.dataset.max_video_len - 3
+        self.gen_param = config.gen_param if config.gen_param is not None \
+            else {}
+
+    def forward(
+        self,
+        caps,
+        cmasks,
+        vfeats,
+        vmasks,
+        attention_mask,
+        video_label=None,
+        text_label=None,
+        **kwargs
+    ):
+        """use pre-trained LM header for generation."""
+        attention_mask, token_type_ids = self._mm_on_the_fly(
+            cmasks, vmasks, attention_mask)
+
+        outputs = self.mm_encoder(
+            input_ids=caps,
+            input_video_embeds=vfeats,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            masked_lm_labels=text_label,
+        )
+        return {"logits": outputs[0]}
+
+    @torch.no_grad()
+    def generate(
+        self,
+        caps, cmasks, vfeats, vmasks,
+        attention_mask=None,
+        bos_token_id=None,
+        eos_token_id=None,
+        **kwargs
+    ):
+        # a simplified interface from
+        # https://huggingface.co/transformers/v3.4.0/_modules/transformers/generation_utils.html#GenerationMixin.generate
+
+        # caps now only have
+        # [CLS], [SEP] (for video) and [CLS] (as bos_token)
+        assert caps.size(1) == 3
+
+        attention_mask, token_type_ids = self._mm_on_the_fly(
+            cmasks, vmasks, attention_mask)
+
+        output = self.mm_encoder.generate(
+            input_ids=caps,
+            input_video_embeds=vfeats,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            max_length=self.max_length,
+            **self.gen_param
+        )
+        return output
+
+
+class MMBertForNLG(BertPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.bert = MMBertModel(config)
+        self.videomlp = VideoTokenMLP(config)
+        # use `BertOnlyMLMHead` rather than `BertGenerationOnlyLMHead`
+        # so the pretrained MLM head weights can be reused.
+        self.cls = BertOnlyMLMHead(config)
+        self.hidden_size = config.hidden_size
+        self.init_weights()
+
+    def get_output_embeddings(self):
+        return self.cls.predictions.decoder
+
+    def forward(
+        self,
+        input_ids=None,
+        input_video_embeds=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        masked_lm_labels=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        # similar to MMBertForMFMMLM without MFM.
+        video_tokens = self.videomlp(input_video_embeds)
+        outputs = self.bert(
+            input_ids,
+            video_tokens,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+
+        prediction_scores = None
+        if masked_lm_labels is not None:
+            text_offset = input_video_embeds.size(1) + 1  # [CLS]
+            # recover caps format: [CLS] [SEP] text [SEP]
+            text_sequence_output = torch.cat(
+                [sequence_output[:, :1], sequence_output[:, text_offset:]],
+                dim=1
+            )
+
+            # only compute scores for the selected (masked) tokens during training, to speed things up.
+            hidden_size = text_sequence_output.size(-1)
+            # masked_lm_labels = masked_lm_labels.reshape(-1)
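+            # positions labeled -100 are ignored, following the standard HF
+            # masked-LM convention.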
+            labels_mask = masked_lm_labels != -100
+
+            selected_text_output = text_sequence_output.masked_select(
+                labels_mask.unsqueeze(-1)
+            ).view(-1, hidden_size)
+            prediction_scores = self.cls(selected_text_output)
+
+        if not return_dict:
+            output = (
+                prediction_scores,
+            ) + outputs[2:]
+            return output
+
+        # for generation.
+        text_offset = input_video_embeds.size(1) + 2  # [CLS] + [SEP]
+        text_sequence_output = sequence_output[:, text_offset:]
+        prediction_scores = self.cls(text_sequence_output)
+        return CausalLMOutput(
+            loss=None,
+            logits=prediction_scores,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        input_video_embeds,
+        attention_mask=None,
+        token_type_ids=None,
+        **model_kwargs
+    ):
+        # must return a dictionary.
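+        # trim attention_mask and token_type_ids to the current text + video
+        # length, since generation grows input_ids one token per step.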
+        seq_len = input_ids.size(1) + input_video_embeds.size(1)
+        if attention_mask is not None:
+            if len(attention_mask.size()) == 4:
+                attention_mask = attention_mask[:, :, :seq_len, :seq_len]
+            elif len(attention_mask.size()) == 3:
+                attention_mask = attention_mask[:, :seq_len, :seq_len]
+            else:
+                attention_mask = attention_mask[:, :seq_len]
+        if token_type_ids is not None:
+            token_type_ids = token_type_ids[:, :seq_len]
+
+        return {
+            "input_ids": input_ids,
+            "input_video_embeds": input_video_embeds,
+            "attention_mask": attention_mask,
+            "token_type_ids": token_type_ids,
+        }
+
+    @torch.no_grad()
+    def generate(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        max_length: Optional[int] = None,
+        min_length: Optional[int] = None,
+        do_sample: Optional[bool] = None,
+        early_stopping: Optional[bool] = None,
+        num_beams: Optional[int] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        repetition_penalty: Optional[float] = None,
+        bad_words_ids: Optional[Iterable[int]] = None,
+        bos_token_id: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        length_penalty: Optional[float] = None,
+        no_repeat_ngram_size: Optional[int] = None,
+        num_return_sequences: Optional[int] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        decoder_start_token_id: Optional[int] = None,
+        use_cache: Optional[bool] = None,
+        **model_kwargs
+    ) -> torch.LongTensor:
+        r"""
+        Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
+        beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.
+        Adapted in part from `Facebook's XLM beam search code
+        <https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529>`__.
+        Apart from :obj:`input_ids` and :obj:`attention_mask`, all the arguments below will default to the value of the
+        attribute of the same name inside the :class:`~transformers.PretrainedConfig` of the model. The default values
+        indicated are the default values of those config attributes.
+        Most of these parameters are explained in more detail in `this blog post
+        <https://huggingface.co/blog/how-to-generate>`__.
+        Parameters:
+            input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+                The sequence used as a prompt for the generation. If :obj:`None` the method initializes
+                it as an empty :obj:`torch.LongTensor` of shape :obj:`(1,)`.
+            decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+                initial input_ids for the decoder of encoder-decoder type models. If :obj:`None` then only
+                decoder_start_token_id is passed as the first token to the decoder.
+            max_length (:obj:`int`, `optional`, defaults to 20):
+                The maximum length of the sequence to be generated.
+            min_length (:obj:`int`, `optional`, defaults to 10):
+                The minimum length of the sequence to be generated.
+            do_sample (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                Whether or not to use sampling; use greedy decoding otherwise.
+            early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not.
+            num_beams (:obj:`int`, `optional`, defaults to 1):
+                Number of beams for beam search. 1 means no beam search.
+            temperature (:obj:`float`, `optional`, defaults to 1.0):
+                The value used to module the next token probabilities.
+            top_k (:obj:`int`, `optional`, defaults to 50):
+                The number of highest probability vocabulary tokens to keep for top-k-filtering.
+            top_p (:obj:`float`, `optional`, defaults to 1.0):
+                If set to float < 1, only the most probable tokens with probabilities that add up to ``top_p`` or
+                higher are kept for generation.
+            repetition_penalty (:obj:`float`, `optional`, defaults to 1.0):
+                The parameter for repetition penalty. 1.0 means no penalty. See `this paper
+                <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
+            pad_token_id (:obj:`int`, `optional`):
+                The id of the `padding` token.
+            bos_token_id (:obj:`int`, `optional`):
+                The id of the `beginning-of-sequence` token.
+            eos_token_id (:obj:`int`, `optional`):
+                The id of the `end-of-sequence` token.
+            length_penalty (:obj:`float`, `optional`, defaults to 1.0):
+                Exponential penalty to the length. 1.0 means no penalty.
+                Set to values < 1.0 in order to encourage the model to generate shorter sequences, to a value > 1.0 in
+                order to encourage the model to produce longer sequences.
+            no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
+                If set to int > 0, all ngrams of that size can only occur once.
+            bad_words_ids(:obj:`List[int]`, `optional`):
+                List of token ids that are not allowed to be generated. In order to get the tokens of the words that
+                should not appear in the generated text, use :obj:`tokenizer.encode(bad_word, add_prefix_space=True)`.
+            num_return_sequences(:obj:`int`, `optional`, defaults to 1):
+                The number of independently computed returned sequences for each element in the batch.
+            attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+                Mask to avoid performing attention on padding token indices. Mask values are in ``[0, 1]``, 1 for
+                tokens that are not masked, and 0 for masked tokens.
+                If not provided, will default to a tensor the same shape as :obj:`input_ids` that masks the pad token.
+                `What are attention masks? <../glossary.html#attention-mask>`__
+            decoder_start_token_id (:obj:`int`, `optional`):
+                If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
+            use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`):
+                Whether or not the model should use the past last key/values attentions (if applicable to the model) to
+                speed up decoding.
+            model_kwargs:
+                Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.
+        Return:
+            :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`:
+            The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or
+            shorter if all batches finished early due to the :obj:`eos_token_id`.
+        Examples::
+            tokenizer = AutoTokenizer.from_pretrained('distilgpt2')   # Initialize tokenizer
+            model = AutoModelWithLMHead.from_pretrained('distilgpt2')    # Download model and configuration from S3 and cache.
+            outputs = model.generate(max_length=40)  # do greedy decoding
+            print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
+            tokenizer = AutoTokenizer.from_pretrained('openai-gpt')   # Initialize tokenizer
+            model = AutoModelWithLMHead.from_pretrained('openai-gpt')    # Download model and configuration from S3 and cache.
+            input_context = 'The dog'
+            input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
+            outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5)  # generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog'
+            for i in range(3): #  3 output sequences were generated
+                print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
+            tokenizer = AutoTokenizer.from_pretrained('distilgpt2')   # Initialize tokenizer
+            model = AutoModelWithLMHead.from_pretrained('distilgpt2')    # Download model and configuration from S3 and cache.
+            input_context = 'The dog'
+            input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
+            outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3, do_sample=True)  # generate 3 candidates using sampling
+            for i in range(3): #  3 output sequences were generated
+                print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
+            tokenizer = AutoTokenizer.from_pretrained('ctrl')   # Initialize tokenizer
+            model = AutoModelWithLMHead.from_pretrained('ctrl')    # Download model and configuration from S3 and cache.
+            input_context = 'Legal My neighbor is'  # "Legal" is one of the control codes for ctrl
+            input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
+            outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2)  # generate sequences
+            print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
+            tokenizer = AutoTokenizer.from_pretrained('gpt2')   # Initialize tokenizer
+            model = AutoModelWithLMHead.from_pretrained('gpt2')    # Download model and configuration from S3 and cache.
+            input_context = 'My cute dog'
+            bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']]
+            input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
+            outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids)  # generate sequences without allowing bad_words to be generated
+        """
+
+        # We cannot generate if the model does not have a LM head
+        if self.get_output_embeddings() is None:
+            raise AttributeError(
+                "You tried to generate sequences with a model that does not have a LM Head."
+                "Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`, `XLMWithLMHeadModel`, `BartForConditionalGeneration` )"
+            )
+
+        max_length = max_length if max_length is not None else self.config.max_length
+        min_length = min_length if min_length is not None else self.config.min_length
+        do_sample = do_sample if do_sample is not None else self.config.do_sample
+        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        num_beams = num_beams if num_beams is not None else self.config.num_beams
+        temperature = temperature if temperature is not None else self.config.temperature
+        top_k = top_k if top_k is not None else self.config.top_k
+        top_p = top_p if top_p is not None else self.config.top_p
+        repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
+        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
+        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
+        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
+        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
+        no_repeat_ngram_size = (
+            no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
+        )
+        bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
+        num_return_sequences = (
+            num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
+        )
+        decoder_start_token_id = (
+            decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
+        )
+
+        if input_ids is not None:
+            batch_size = input_ids.shape[0]  # overridden by the input batch_size
+        else:
+            batch_size = 1
+
+        assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
+        assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
+        assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
+        assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
+        assert isinstance(use_cache, bool), "`use_cache` should be a boolean."
+        assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
+        assert temperature > 0, "`temperature` should be strictly positive."
+        assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
+        assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
+        assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
+        assert input_ids is not None or (
+            isinstance(bos_token_id, int) and bos_token_id >= 0
+        ), "If input_ids is not defined, `bos_token_id` should be a positive integer."
+        assert pad_token_id is None or (
+            isinstance(pad_token_id, int) and (pad_token_id >= 0)
+        ), "`pad_token_id` should be a positive integer."
+        assert (eos_token_id is None) or (
+            isinstance(eos_token_id, int) and (eos_token_id >= 0)
+        ), "`eos_token_id` should be a positive integer."
+        assert length_penalty > 0, "`length_penalty` should be strictly positive."
+        assert (
+            isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
+        ), "`no_repeat_ngram_size` should be a positive integer."
+        assert (
+            isinstance(num_return_sequences, int) and num_return_sequences > 0
+        ), "`num_return_sequences` should be a strictly positive integer."
+        assert (
+            bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
+        ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"
+
+        if input_ids is None:
+            assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
+                "you should either supply a context to complete as `input_ids` input "
+                "or a `bos_token_id` (integer >= 0) as a first token to start the generation."
+            )
+            input_ids = torch.full(
+                (batch_size, 1),
+                bos_token_id,
+                dtype=torch.long,
+                device=next(self.parameters()).device,
+            )
+        else:
+            assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."
+
+        # do not allow duplicate outputs when doing greedy decoding
+        if do_sample is False:
+            if num_beams == 1:
+                # no_beam_search greedy generation conditions
+                assert (
+                    num_return_sequences == 1
+                ), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1"
+
+            else:
+                # beam_search greedy generation conditions
+                assert (
+                    num_beams >= num_return_sequences
+                ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"
+
+        # create attention mask if necessary
+        # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140
+        if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids):
+            attention_mask = input_ids.ne(pad_token_id).long()
+        elif attention_mask is None:
+            attention_mask = input_ids.new_ones(input_ids.shape)
+
+        # set pad_token_id to eos_token_id if not set. Important that this is done after
+        # attention_mask is created
+        if pad_token_id is None and eos_token_id is not None:
+            print(
+                "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
+            )
+            pad_token_id = eos_token_id
+
+        # vocab size
+        if hasattr(self.config, "vocab_size"):
+            vocab_size = self.config.vocab_size
+        elif (
+            self.config.is_encoder_decoder
+            and hasattr(self.config, "decoder")
+            and hasattr(self.config.decoder, "vocab_size")
+        ):
+            vocab_size = self.config.decoder.vocab_size
+        else:
+            raise ValueError("either self.config.vocab_size or self.config.decoder.vocab_size needs to be defined")
+
+        # set effective batch size and effective batch multiplier according to do_sample
+        if do_sample:
+            effective_batch_size = batch_size * num_return_sequences
+            effective_batch_mult = num_return_sequences
+        else:
+            effective_batch_size = batch_size
+            effective_batch_mult = 1
+
+        if self.config.is_encoder_decoder:
+            if decoder_start_token_id is None:
+                # see if BOS token can be used for decoder_start_token_id
+                if bos_token_id is not None:
+                    decoder_start_token_id = bos_token_id
+                elif (
+                    hasattr(self.config, "decoder")
+                    and hasattr(self.config.decoder, "bos_token_id")
+                    and self.config.decoder.bos_token_id is not None
+                ):
+                    decoder_start_token_id = self.config.decoder.bos_token_id
+                else:
+                    raise ValueError(
+                        "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
+                    )
+
+            assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
+            assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
+
+            # get encoder and store encoder outputs
+            encoder = self.get_encoder()
+            encoder_outputs: ModelOutput = encoder(input_ids, attention_mask=attention_mask, return_dict=True)
+
+        # Expand input ids if num_beams > 1 or num_return_sequences > 1
+        if num_return_sequences > 1 or num_beams > 1:
+            # TODO: make this a call-back function.
+            # input_ids=caps,
+            # input_video_embeds=vfeats,
+            # attention_mask=attention_mask,
+            # token_type_ids=token_type_ids,
+            input_video_embeds = model_kwargs.pop("input_video_embeds", None)
+            token_type_ids = model_kwargs.pop("token_type_ids", None)
+
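+            # tile text ids, video features, attention mask and token type ids
+            # along a beam/return-sequence dimension, then flatten back to
+            # (effective_batch_size * num_beams, ...).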
+            input_ids_len = input_ids.shape[-1]
+            input_ids = input_ids.unsqueeze(1).expand(
+                 batch_size, effective_batch_mult * num_beams, input_ids_len)
+
+            input_video_embeds_len, input_video_embeds_hidden = input_video_embeds.size(1), input_video_embeds.size(2)
+            input_video_embeds = input_video_embeds.unsqueeze(1).expand(
+                batch_size, effective_batch_mult * num_beams, input_video_embeds_len, input_video_embeds_hidden)
+
+            attention_mask_from_len, attention_mask_to_len = attention_mask.size(1), attention_mask.size(2)
+            attention_mask = attention_mask.unsqueeze(1).expand(
+                batch_size, effective_batch_mult * num_beams, attention_mask_from_len, attention_mask_to_len
+            )
+
+            token_type_ids_len = token_type_ids.size(1)
+            token_type_ids = token_type_ids.unsqueeze(1).expand(
+                batch_size, effective_batch_mult * num_beams, token_type_ids_len
+            )
+
+            # contiguous ...
+            input_ids = input_ids.contiguous().view(
+                effective_batch_size * num_beams, input_ids_len
+            )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)
+
+            input_video_embeds = input_video_embeds.contiguous().view(
+                effective_batch_size * num_beams, input_video_embeds_len, input_video_embeds_hidden)
+
+            attention_mask = attention_mask.contiguous().view(
+                effective_batch_size * num_beams, attention_mask_from_len, attention_mask_to_len
+            )  # shape: (batch_size * num_return_sequences * num_beams, from_len, to_len)
+
+            token_type_ids = token_type_ids.contiguous().view(
+                effective_batch_size * num_beams, token_type_ids_len
+            )
+
+            model_kwargs["input_video_embeds"] = input_video_embeds
+            model_kwargs["token_type_ids"] = token_type_ids
+
+        if self.config.is_encoder_decoder:
+            device = next(self.parameters()).device
+            if decoder_input_ids is not None:
+                # give initial decoder input ids
+                input_ids = decoder_input_ids.repeat(effective_batch_size * num_beams, 1).to(device)
+            else:
+                # create empty decoder input_ids
+                input_ids = torch.full(
+                    (effective_batch_size * num_beams, 1),
+                    decoder_start_token_id,
+                    dtype=torch.long,
+                    device=device,
+                )
+            cur_len = input_ids.shape[-1]
+
+            assert (
+                batch_size == encoder_outputs.last_hidden_state.shape[0]
+            ), f"expected encoder_outputs.last_hidden_state to have 1st dimension bs={batch_size}, got {encoder_outputs.last_hidden_state.shape[0]} "
+
+            # expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
+            expanded_batch_idxs = (
+                torch.arange(batch_size)
+                .view(-1, 1)
+                .repeat(1, num_beams * effective_batch_mult)
+                .view(-1)
+                .to(input_ids.device)
+            )
+
+            # expand encoder_outputs
+            encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
+                0, expanded_batch_idxs
+            )
+
+            # save encoder_outputs in `model_kwargs`
+            model_kwargs["encoder_outputs"] = encoder_outputs
+
+        else:
+            cur_len = input_ids.shape[-1]
+
+        assert (
+            cur_len < max_length
+        ), f"The context has {cur_len} number of tokens, but `max_length` is only {max_length}. Please make sure that `max_length` is bigger than the number of tokens, by setting either `generate(max_length=...,...)` or `config.max_length = ...`"
+
+        if num_beams > 1:
+            output = self._generate_beam_search(
+                input_ids,
+                cur_len=cur_len,
+                max_length=max_length,
+                min_length=min_length,
+                do_sample=do_sample,
+                early_stopping=early_stopping,
+                temperature=temperature,
+                top_k=top_k,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                no_repeat_ngram_size=no_repeat_ngram_size,
+                bad_words_ids=bad_words_ids,
+                pad_token_id=pad_token_id,
+                eos_token_id=eos_token_id,
+                batch_size=effective_batch_size,
+                num_return_sequences=num_return_sequences,
+                length_penalty=length_penalty,
+                num_beams=num_beams,
+                vocab_size=vocab_size,
+                attention_mask=attention_mask,
+                use_cache=use_cache,
+                model_kwargs=model_kwargs,
+            )
+        else:
+            output = self._generate_no_beam_search(
+                input_ids,
+                cur_len=cur_len,
+                max_length=max_length,
+                min_length=min_length,
+                do_sample=do_sample,
+                temperature=temperature,
+                top_k=top_k,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                no_repeat_ngram_size=no_repeat_ngram_size,
+                bad_words_ids=bad_words_ids,
+                pad_token_id=pad_token_id,
+                eos_token_id=eos_token_id,
+                batch_size=effective_batch_size,
+                attention_mask=attention_mask,
+                use_cache=use_cache,
+                model_kwargs=model_kwargs,
+            )
+
+        return output
+
+    def _generate_beam_search(
+        self,
+        input_ids,
+        cur_len,
+        max_length,
+        min_length,
+        do_sample,
+        early_stopping,
+        temperature,
+        top_k,
+        top_p,
+        repetition_penalty,
+        no_repeat_ngram_size,
+        bad_words_ids,
+        pad_token_id,
+        eos_token_id,
+        batch_size,
+        num_return_sequences,
+        length_penalty,
+        num_beams,
+        vocab_size,
+        attention_mask,
+        use_cache,
+        model_kwargs,
+    ):
+        """Generate sequences for each example with beam search."""
+
+        # generated hypotheses
+        generated_hyps = [
+            BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping)
+            for _ in range(batch_size)
+        ]
+
+        # scores for each sentence in the beam
+        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
+
+        # for greedy decoding, only tokens of the first beam are considered, to avoid producing the exact same sequence num_beams times
+        if do_sample is False:
+            beam_scores[:, 1:] = -1e9
+        beam_scores = beam_scores.view(-1)  # shape (batch_size * num_beams,)
+
+        # cache compute states
+        past = None
+
+        # done sentences
+        done = [False for _ in range(batch_size)]
+
+        while cur_len < max_length:
+            model_inputs = self.prepare_inputs_for_generation(
+                input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
+            )
+            outputs = self(**model_inputs, return_dict=True)  # (batch_size * num_beams, cur_len, vocab_size)
+            next_token_logits = outputs.logits[:, -1, :]  # (batch_size * num_beams, vocab_size)
+
+            # if model has past, then set the past variable to speed up decoding
+            if "past_key_values" in outputs:
+                past = outputs.past_key_values
+            elif "mems" in outputs:
+                past = outputs.mems
+
+            if self.config.is_encoder_decoder and do_sample is False:
+                # TODO (PVP) still a bit hacky here - there might be a better solution
+                next_token_logits = self.adjust_logits_during_generation(
+                    next_token_logits, cur_len=cur_len, max_length=max_length
+                )
+
+            scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)
+
+            scores = self.postprocess_next_token_scores(
+                scores=scores,
+                input_ids=input_ids,
+                no_repeat_ngram_size=no_repeat_ngram_size,
+                bad_words_ids=bad_words_ids,
+                cur_len=cur_len,
+                min_length=min_length,
+                max_length=max_length,
+                eos_token_id=eos_token_id,
+                repetition_penalty=repetition_penalty,
+                batch_size=batch_size,
+                num_beams=num_beams,
+            )
+
+            assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(
+                scores.shape, (batch_size * num_beams, vocab_size)
+            )
+
+            if do_sample:
+                _scores = scores + beam_scores[:, None].expand_as(scores)  # (batch_size * num_beams, vocab_size)
+                # Temperature
+                if temperature != 1.0:
+                    _scores = _scores / temperature
+                # Top-p/top-k filtering
+                _scores = top_k_top_p_filtering(
+                    _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
+                )  # (batch_size * num_beams, vocab_size)
+                # re-organize to group the beam together to sample from all beam_idxs
+                _scores = _scores.contiguous().view(
+                    batch_size, num_beams * vocab_size
+                )  # (batch_size, num_beams * vocab_size)
+
+                # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
+                probs = F.softmax(_scores, dim=-1)
+                next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)  # (batch_size, num_beams * 2)
+                # Compute next scores
+                next_scores = torch.gather(_scores, -1, next_tokens)  # (batch_size, num_beams * 2)
+                # sort the sampled vector to make sure that the first num_beams samples are the best
+                next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1)
+                next_tokens = torch.gather(next_tokens, -1, next_scores_indices)  # (batch_size, num_beams * 2)
+
+            else:
+                next_scores = scores + beam_scores[:, None].expand_as(scores)  # (batch_size * num_beams, vocab_size)
+
+                # re-organize to group the beams together (we keep the top hypotheses across beams)
+                next_scores = next_scores.view(
+                    batch_size, num_beams * vocab_size
+                )  # (batch_size, num_beams * vocab_size)
+
+                next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
+
+            assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams)
+
+            # next batch beam content
+            next_batch_beam = []
+
+            # for each sentence
+            for batch_idx in range(batch_size):
+
+                # if we are done with this sentence, add a pad token
+                if done[batch_idx]:
+                    assert (
+                        len(generated_hyps[batch_idx]) >= num_beams
+                    ), "Batch can only be done if at least {} beams have been generated".format(num_beams)
+                    assert (
+                        eos_token_id is not None and pad_token_id is not None
+                    ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
+                    next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams)  # pad the batch
+                    continue
+
+                # next sentence beam content, this will get added to next_batch_beam
+                next_sent_beam = []
+
+                # next tokens for this sentence
+                for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
+                    zip(next_tokens[batch_idx], next_scores[batch_idx])
+                ):
+                    # get beam and token IDs
+                    beam_id = beam_token_id // vocab_size
+                    token_id = beam_token_id % vocab_size
+
+                    effective_beam_id = batch_idx * num_beams + beam_id
+                    # add to generated hypotheses if end of sentence
+                    if (eos_token_id is not None) and (token_id.item() == eos_token_id):
+                        # if beam_token does not belong to top num_beams tokens, it should not be added
+                        is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
+                        if is_beam_token_worse_than_top_num_beams:
+                            continue
+                        generated_hyps[batch_idx].add(
+                            input_ids[effective_beam_id].clone(),
+                            beam_token_score.item(),
+                        )
+                    else:
+                        # add next predicted token since it is not eos_token
+                        next_sent_beam.append((beam_token_score, token_id, effective_beam_id))
+
+                    # once the beam for next step is full, don't add more tokens to it.
+                    if len(next_sent_beam) == num_beams:
+                        break
+
+                # Check if we are done so that we can save a pad step if all(done)
+                done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
+                    next_scores[batch_idx].max().item(), cur_len
+                )
+
+                # update next beam content
+                assert len(next_sent_beam) == num_beams, "Beam should always be full"
+                next_batch_beam.extend(next_sent_beam)
+                assert len(next_batch_beam) == num_beams * (batch_idx + 1), "We should have added num_beams each step"
+
+            # stop when we are done with each sentence
+            if all(done):
+                break
+
+            # sanity check / prepare next batch
+            assert len(next_batch_beam) == batch_size * num_beams
+            beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
+            beam_tokens = input_ids.new([x[1] for x in next_batch_beam])
+            beam_idx = input_ids.new([x[2] for x in next_batch_beam])
+
+            # re-order batch and update current length
+            input_ids = input_ids[beam_idx, :]
+            input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
+            cur_len = cur_len + 1
+
+            # re-order internal states
+            if past is not None:
+                past = self._reorder_cache(past, beam_idx)
+
+            # extend attention_mask for new generated input if only decoder
+            # (huxu): moved out since we trim the attention_mask ourselves.
+            # if self.config.is_encoder_decoder is False:
+            #    attention_mask = torch.cat(
+            #        [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+            #    )
+
+        # finalize all open beam hypotheses and add to generated hypotheses
+        for batch_idx in range(batch_size):
+            if done[batch_idx]:
+                continue
+
+            # test that beam scores match previously calculated scores if not eos and batch_idx not done
+            if eos_token_id is not None and all(
+                (token_id % vocab_size).item() != eos_token_id for token_id in next_tokens[batch_idx]
+            ):
+                assert torch.all(
+                    next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
+                ), "If batch_idx is not done, final next scores: {} have to equal accumulated beam_scores: {}".format(
+                    next_scores[:, :num_beams][batch_idx],
+                    beam_scores.view(batch_size, num_beams)[batch_idx],
+                )
+
+            # need to add best num_beams hypotheses to generated hyps
+            for beam_id in range(num_beams):
+                effective_beam_id = batch_idx * num_beams + beam_id
+                final_score = beam_scores[effective_beam_id].item()
+                final_tokens = input_ids[effective_beam_id]
+                generated_hyps[batch_idx].add(final_tokens, final_score)
+
+        # depending on whether greedy generation is wanted or not, define output_batch_size and output_num_return_sequences_per_batch accordingly
+        output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
+        output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences
+
+        # select the best hypotheses
+        sent_lengths = input_ids.new(output_batch_size)
+        best = []
+
+        # retrieve best hypotheses
+        for i, hypotheses in enumerate(generated_hyps):
+            sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
+            for j in range(output_num_return_sequences_per_batch):
+                effective_batch_idx = output_num_return_sequences_per_batch * i + j
+                best_hyp = sorted_hyps.pop()[1]
+                sent_lengths[effective_batch_idx] = len(best_hyp)
+                best.append(best_hyp)
+
+        # prepare for adding eos
+        sent_max_len = min(sent_lengths.max().item() + 1, max_length)
+        decoded = input_ids.new(output_batch_size, sent_max_len)
+        # shorter batches are padded if needed
+        if sent_lengths.min().item() != sent_lengths.max().item():
+            assert pad_token_id is not None, "`pad_token_id` has to be defined"
+            decoded.fill_(pad_token_id)
+
+        # fill with hypotheses and eos_token_id if the latter fits in
+        for i, hypo in enumerate(best):
+            decoded[i, : sent_lengths[i]] = hypo
+            if sent_lengths[i] < max_length:
+                decoded[i, sent_lengths[i]] = eos_token_id
+
+        return decoded
+
+    def _generate_no_beam_search(
+        self,
+        input_ids,
+        cur_len,
+        max_length,
+        min_length,
+        do_sample,
+        temperature,
+        top_k,
+        top_p,
+        repetition_penalty,
+        no_repeat_ngram_size,
+        bad_words_ids,
+        pad_token_id,
+        eos_token_id,
+        batch_size,
+        attention_mask,
+        use_cache,
+        model_kwargs,
+    ):
+        """Generate sequences for each example without beam search (num_beams == 1).
+        All returned sequences are generated independently.
+        """
+        # length of generated sentences / unfinished sentences
+        unfinished_sents = input_ids.new(batch_size).fill_(1)
+        sent_lengths = input_ids.new(batch_size).fill_(max_length)
+
+        past = None
+        while cur_len < max_length:
+            model_inputs = self.prepare_inputs_for_generation(
+                input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
+            )
+
+            outputs = self(**model_inputs, return_dict=True)
+            next_token_logits = outputs.logits[:, -1, :]
+            scores = self.postprocess_next_token_scores(
+                scores=next_token_logits,
+                input_ids=input_ids,
+                no_repeat_ngram_size=no_repeat_ngram_size,
+                bad_words_ids=bad_words_ids,
+                cur_len=cur_len,
+                min_length=min_length,
+                max_length=max_length,
+                eos_token_id=eos_token_id,
+                repetition_penalty=repetition_penalty,
+                batch_size=batch_size,
+                num_beams=1,
+            )
+
+            # if model has past, then set the past variable to speed up decoding
+            if "past_key_values" in outputs:
+                past = outputs.past_key_values
+            elif "mems" in outputs:
+                past = outputs.mems
+
+            if do_sample:
+                # Temperature (higher temperature => more likely to sample low probability tokens)
+                if temperature != 1.0:
+                    scores = scores / temperature
+                # Top-p/top-k filtering
+                next_token_logscores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p)
+                # Sample
+                probs = F.softmax(next_token_logscores, dim=-1)
+                next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
+            else:
+                # Greedy decoding
+                next_token = torch.argmax(next_token_logits, dim=-1)
+
+                # print(next_token_logits[0,next_token[0]], next_token_logits[0,eos_token_id])
+
+            # update generations and finished sentences
+            if eos_token_id is not None:
+                # pad finished sentences if eos_token_id exists
+                tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
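+                # e.g. unfinished_sents = [1, 0], next_token = [42, 7] and
+                # pad_token_id = 0 give tokens_to_add = [42, 0].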
+            else:
+                tokens_to_add = next_token
+
+            # add token and increase length by one
+            input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
+            cur_len = cur_len + 1
+
+            if eos_token_id is not None:
+                eos_in_sents = tokens_to_add == eos_token_id
+                # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
+                is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
+                sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len)
+                # unfinished_sents is set to zero if eos in sentence
+                unfinished_sents.mul_((~eos_in_sents).long())
+
+            # stop when there is a </s> in each sentence, or if we exceed the maximum length
+            if unfinished_sents.max() == 0:
+                break
+
+            # extend attention_mask for new generated input if only decoder
+            # if self.config.is_encoder_decoder is False:
+            #     attention_mask = torch.cat(
+            #         [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+            #     )
+
+        return input_ids

+ 734 - 0
examples/MMPT/mmpt/models/transformermodel.py

@@ -0,0 +1,734 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Copyright (c) Facebook, Inc. All Rights Reserved
+
+import torch
+
+from torch import nn
+
+try:
+    from transformers.modeling_bert import (
+        BertPreTrainedModel,
+        BertModel,
+        BertEncoder,
+        BertPredictionHeadTransform,
+    )
+except ImportError:
+    pass
+
+from ..modules import VideoTokenMLP, MMBertEmbeddings
+
+
+# --------------- fine-tuning models ---------------
+class MMBertForJoint(BertPreTrainedModel):
+    """A BertModel with an isolated attention mask that separates the modalities."""
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.videomlp = VideoTokenMLP(config)
+        self.bert = MMBertModel(config)
+        self.init_weights()
+
+    def forward(
+        self,
+        input_ids=None,
+        input_video_embeds=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        next_sentence_label=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        separate_forward_split=None,
+    ):
+        return_dict = (
+            return_dict if return_dict is not None
+            else self.config.use_return_dict
+        )
+        video_tokens = self.videomlp(input_video_embeds)
+
+        outputs = self.bert(
+            input_ids,
+            video_tokens,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            separate_forward_split=separate_forward_split,
+        )
+
+        return outputs
+
+
+class MMBertForTokenClassification(BertPreTrainedModel):
+    """A BertModel similar to MMJointUni, with an extra wrapper layer
+    to be fine-tuned from another pretrained MMFusion model."""
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.videomlp = VideoTokenMLP(config)
+        self.bert = MMBertModel(config)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        # TODO(huxu): 779 is the number of classes for COIN: move to config?
+        self.classifier = nn.Linear(config.hidden_size, 779)
+        self.init_weights()
+
+    def forward(
+        self,
+        input_ids=None,
+        input_video_embeds=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        next_sentence_label=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        separate_forward_split=None,
+    ):
+        return_dict = (
+            return_dict if return_dict is not None
+            else self.config.use_return_dict
+        )
+
+        video_tokens = self.videomlp(input_video_embeds)
+        outputs = self.bert(
+            input_ids,
+            video_tokens,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            separate_forward_split=separate_forward_split,
+        )
+
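+        # outputs[0] is the token-level sequence output; the linear head
+        # classifies every position into one of the 779 COIN classes
+        # defined in __init__.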
+        return (self.classifier(outputs[0]),)
+
+
+# ------------ pre-training models ----------------
+
+class MMBertForEncoder(BertPreTrainedModel):
+    """A BertModel for Contrastive Learning."""
+    def __init__(self, config):
+        super().__init__(config)
+        self.videomlp = VideoTokenMLP(config)
+        self.bert = MMBertModel(config)
+        self.init_weights()
+
+    def forward(
+        self,
+        input_ids=None,
+        input_video_embeds=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        return_dict = (
+            return_dict if return_dict is not None
+            else self.config.use_return_dict
+        )
+        if input_video_embeds is not None:
+            video_tokens = self.videomlp(input_video_embeds)
+        else:
+            video_tokens = None
+
+        outputs = self.bert(
+            input_ids,
+            video_tokens,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        return outputs
+
+
+class MMBertForMFMMLM(BertPreTrainedModel):
+    """A BertModel with shared prediction head on MFM-MLM."""
+    def __init__(self, config):
+        super().__init__(config)
+        self.videomlp = VideoTokenMLP(config)
+        self.bert = MMBertModel(config)
+        self.cls = MFMMLMHead(config)
+        self.hidden_size = config.hidden_size
+        self.init_weights()
+
+    def get_output_embeddings(self):
+        return self.cls.predictions.decoder
+
+    def forward(
+        self,
+        input_ids=None,
+        input_video_embeds=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        masked_frame_labels=None,
+        target_video_hidden_states=None,
+        non_masked_frame_mask=None,
+        masked_lm_labels=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        return_dict = (
+            return_dict if return_dict is not None
+            else self.config.use_return_dict
+        )
+        if input_video_embeds is not None:
+            video_tokens = self.videomlp(input_video_embeds)
+        else:
+            video_tokens = None
+
+        if target_video_hidden_states is not None:
+            target_video_hidden_states = self.videomlp(
+                target_video_hidden_states)
+
+            non_masked_frame_hidden_states = video_tokens.masked_select(
+                non_masked_frame_mask.unsqueeze(-1)
+            ).view(-1, self.hidden_size)
+
+        outputs = self.bert(
+            input_ids,
+            video_tokens,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+
+        mfm_scores, prediction_scores = None, None
+        if masked_frame_labels is not None and masked_lm_labels is not None:
+            # split the sequence.
+            text_offset = masked_frame_labels.size(1) + 1  # [CLS]
+            video_sequence_output = sequence_output[
+                :, 1:text_offset
+            ]  # remove [SEP] as not in video_label.
+            text_sequence_output = torch.cat(
+                [sequence_output[:, :1], sequence_output[:, text_offset:]],
+                dim=1
+            )
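+            # layout of sequence_output is [CLS] <video frames> [SEP] <text>;
+            # positions 1..text_offset-1 are the video frames, while [CLS]
+            # plus everything from the video [SEP] onward is treated as text.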
+
+            hidden_size = video_sequence_output.size(-1)
+            selected_video_output = video_sequence_output.masked_select(
+                masked_frame_labels.unsqueeze(-1)
+            ).view(-1, hidden_size)
+
+            # only compute the selected tokens during training to speed things up.
+            hidden_size = text_sequence_output.size(-1)
+            # masked_lm_labels = masked_lm_labels.reshape(-1)
+            labels_mask = masked_lm_labels != -100
+
+            selected_text_output = text_sequence_output.masked_select(
+                labels_mask.unsqueeze(-1)
+            ).view(-1, hidden_size)
+            mfm_scores, prediction_scores = self.cls(
+                selected_video_output,
+                target_video_hidden_states,
+                non_masked_frame_hidden_states,
+                selected_text_output,
+            )
+
+        output = (
+            mfm_scores,
+            prediction_scores,
+        ) + outputs
+        return output
+
+
+class BertMFMMLMPredictionHead(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.transform = BertPredictionHeadTransform(config)
+        # The output weights are the same as the input embeddings, but there is
+        # an output-only bias for each token.
+        self.decoder = nn.Linear(
+            config.hidden_size, config.vocab_size, bias=False)
+
+        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+        # Need a link between the two variables so that the bias is correctly
+        # resized with `resize_token_embeddings`
+        self.decoder.bias = self.bias
+
+    def forward(
+        self,
+        video_hidden_states=None,
+        target_video_hidden_states=None,
+        non_masked_frame_hidden_states=None,
+        text_hidden_states=None,
+    ):
+        video_logits, text_logits = None, None
+        if video_hidden_states is not None:
+            video_hidden_states = self.transform(video_hidden_states)
+            non_masked_frame_logits = torch.mm(
+                video_hidden_states,
+                non_masked_frame_hidden_states.transpose(1, 0)
+            )
+            masked_frame_logits = torch.bmm(
+                video_hidden_states.unsqueeze(1),
+                target_video_hidden_states.unsqueeze(-1),
+            ).squeeze(-1)
+            video_logits = torch.cat(
+                [masked_frame_logits, non_masked_frame_logits], dim=1
+            )
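+            # column 0 holds each masked frame's similarity to its own target
+            # (the positive); the remaining columns hold similarities to the
+            # non-masked frames, which serve as NCE-style negatives.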
+
+        if text_hidden_states is not None:
+            text_hidden_states = self.transform(text_hidden_states)
+            text_logits = self.decoder(text_hidden_states)
+        return video_logits, text_logits
+
+
+class MFMMLMHead(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.predictions = BertMFMMLMPredictionHead(config)
+
+    def forward(
+        self,
+        video_hidden_states=None,
+        target_video_hidden_states=None,
+        non_masked_frame_hidden_states=None,
+        text_hidden_states=None,
+    ):
+        video_logits, text_logits = self.predictions(
+            video_hidden_states,
+            target_video_hidden_states,
+            non_masked_frame_hidden_states,
+            text_hidden_states,
+        )
+        return video_logits, text_logits
+
+
+class MMBertForMTM(MMBertForMFMMLM):
+    def __init__(self, config):
+        BertPreTrainedModel.__init__(self, config)
+        self.videomlp = VideoTokenMLP(config)
+        self.bert = MMBertModel(config)
+        self.cls = MTMHead(config)
+        self.hidden_size = config.hidden_size
+        self.init_weights()
+
+
+class BertMTMPredictionHead(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.transform = BertPredictionHeadTransform(config)
+        self.decoder = nn.Linear(
+            config.hidden_size, config.vocab_size, bias=False)
+
+    def forward(
+        self,
+        video_hidden_states=None,
+        target_video_hidden_states=None,
+        non_masked_frame_hidden_states=None,
+        text_hidden_states=None,
+    ):
+        non_masked_frame_hidden_states = non_masked_frame_hidden_states.transpose(1, 0)
+        video_logits, text_logits = None, None
+        if video_hidden_states is not None:
+            video_hidden_states = self.transform(video_hidden_states)
+
+            masked_frame_logits = torch.bmm(
+                video_hidden_states.unsqueeze(1),
+                target_video_hidden_states.unsqueeze(-1),
+            ).squeeze(-1)
+
+            non_masked_frame_logits = torch.mm(
+                video_hidden_states,
+                non_masked_frame_hidden_states
+            )
+            video_on_vocab_logits = self.decoder(video_hidden_states)
+            video_logits = torch.cat([
+                masked_frame_logits,
+                non_masked_frame_logits,
+                video_on_vocab_logits], dim=1)
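+            # video logits span [own masked target, non-masked frames, the
+            # full text vocabulary], i.e. a joint frame/vocab target space.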
+
+        if text_hidden_states is not None:
+            text_hidden_states = self.transform(text_hidden_states)
+            # text first so label does not need to be shifted.
+            text_on_vocab_logits = self.decoder(text_hidden_states)
+            text_on_video_logits = torch.mm(
+                text_hidden_states,
+                non_masked_frame_hidden_states
+            )
+            text_logits = torch.cat([
+                text_on_vocab_logits,
+                text_on_video_logits
+            ], dim=1)
+
+        return video_logits, text_logits
+
+
+class MTMHead(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.predictions = BertMTMPredictionHead(config)
+
+    def forward(
+        self,
+        video_hidden_states=None,
+        target_video_hidden_states=None,
+        non_masked_frame_hidden_states=None,
+        text_hidden_states=None,
+    ):
+        video_logits, text_logits = self.predictions(
+            video_hidden_states,
+            target_video_hidden_states,
+            non_masked_frame_hidden_states,
+            text_hidden_states,
+        )
+        return video_logits, text_logits
+
+
+class MMBertModel(BertModel):
+    """MMBertModel uses MMBertEmbeddings to support video tokens."""
+
+    def __init__(self, config, add_pooling_layer=True):
+        super().__init__(config)
+        # overwrite embedding
+        self.embeddings = MMBertEmbeddings(config)
+        self.encoder = MultiLayerAttentionMaskBertEncoder(config)
+        self.init_weights()
+
+    def forward(
+        self,
+        input_ids=None,
+        input_video_embeds=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        separate_forward_split=None,
+    ):
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        return_dict = (
+            return_dict if return_dict is not None
+            else self.config.use_return_dict
+        )
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError(
+                "You cannot specify both input_ids "
+                "and inputs_embeds at the same time"
+            )
+        elif input_ids is not None:
+            if input_video_embeds is not None:
+                input_shape = (
+                    input_ids.size(0),
+                    input_ids.size(1) + input_video_embeds.size(1),
+                )
+            else:
+                input_shape = (
+                    input_ids.size(0),
+                    input_ids.size(1),
+                )
+        elif inputs_embeds is not None:
+            if input_video_embeds is not None:
+                input_shape = (
+                    inputs_embeds.size(0),
+                    inputs_embeds.size(1) + input_video_embeds.size(1),
+                )
+            else:
+                input_shape = (
+                    inputs_embeds.size(0),
+                    inputs_embeds.size(1),
+                )
+        else:
+            raise ValueError(
+                "You have to specify either input_ids or inputs_embeds")
+
+        device = input_ids.device if input_ids is not None \
+            else inputs_embeds.device
+
+        if attention_mask is None:
+            attention_mask = torch.ones(input_shape, device=device)
+        if token_type_ids is None:
+            token_type_ids = torch.zeros(
+                input_shape, dtype=torch.long, device=device)
+
+        # We can provide a self-attention mask of dimensions
+        # [batch_size, from_seq_length, to_seq_length]
+        # ourselves in which case
+        # we just need to make it broadcastable to all heads.
+        extended_attention_mask: torch.Tensor = \
+            self.get_extended_attention_mask(
+                attention_mask, input_shape, device)
+
+        # If a 2D or 3D attention mask is provided for the cross-attention
+        # we need to make broadcastable to
+        # [batch_size, num_heads, seq_length, seq_length]
+        if self.config.is_decoder and encoder_hidden_states is not None:
+            (
+                encoder_batch_size,
+                encoder_sequence_length,
+                _,
+            ) = encoder_hidden_states.size()
+            encoder_hidden_shape = (
+                encoder_batch_size, encoder_sequence_length)
+            if encoder_attention_mask is None:
+                encoder_attention_mask = torch.ones(
+                    encoder_hidden_shape, device=device)
+            encoder_extended_attention_mask = self.invert_attention_mask(
+                encoder_attention_mask
+            )
+        else:
+            encoder_extended_attention_mask = None
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicates we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or
+        # [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape
+        # [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+
+        head_mask = self.get_head_mask(
+            head_mask, self.config.num_hidden_layers)
+
+        embedding_output = self.embeddings(
+            input_ids,
+            input_video_embeds,
+            position_ids=position_ids,
+            token_type_ids=token_type_ids,
+            inputs_embeds=inputs_embeds,
+        )
+
+        if separate_forward_split is not None:
+            split_embedding_output = \
+                embedding_output[:, :separate_forward_split]
+            split_extended_attention_mask = extended_attention_mask[
+                :, :, :, :separate_forward_split, :separate_forward_split
+            ]
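+            # run the two halves of the sequence through the encoder
+            # independently: each half only sees its own diagonal block of
+            # the attention mask, so there is no attention across the split;
+            # the outputs are concatenated along the sequence dim afterwards.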
+            split_encoder_outputs = self.encoder(
+                split_embedding_output,
+                attention_mask=split_extended_attention_mask,
+                head_mask=head_mask,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_extended_attention_mask,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+            assert (
+                len(split_encoder_outputs) <= 2
+            ), "merging attention outputs is not supported yet."
+            encoder_outputs = []
+            encoder_outputs.append([split_encoder_outputs[0]])
+            if len(split_encoder_outputs) == 2:
+                encoder_outputs.append([])
+                for _all_hidden_states in split_encoder_outputs[1]:
+                    encoder_outputs[-1].append([_all_hidden_states])
+
+            split_embedding_output = \
+                embedding_output[:, separate_forward_split:]
+            split_extended_attention_mask = extended_attention_mask[
+                :, :, :, separate_forward_split:, separate_forward_split:
+            ]
+
+            split_encoder_outputs = self.encoder(
+                split_embedding_output,
+                attention_mask=split_extended_attention_mask,
+                head_mask=head_mask,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_extended_attention_mask,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+
+            assert (
+                len(split_encoder_outputs) <= 2
+            ), "merging attention outputs is not supported yet."
+            encoder_outputs[0].append(split_encoder_outputs[0])
+            encoder_outputs[0] = torch.cat(encoder_outputs[0], dim=1)
+            if len(split_encoder_outputs) == 2:
+                for layer_idx, _all_hidden_states in enumerate(
+                    split_encoder_outputs[1]
+                ):
+                    encoder_outputs[1][layer_idx].append(_all_hidden_states)
+                    encoder_outputs[1][layer_idx] = torch.cat(
+                        encoder_outputs[1][layer_idx], dim=1
+                    )
+            encoder_outputs = tuple(encoder_outputs)
+        else:
+            encoder_outputs = self.encoder(
+                embedding_output,
+                attention_mask=extended_attention_mask,
+                head_mask=head_mask,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_extended_attention_mask,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+
+        sequence_output = encoder_outputs[0]
+        pooled_output = (
+            self.pooler(sequence_output) if self.pooler is not None else None
+        )
+
+        return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+    def get_extended_attention_mask(self, attention_mask, input_shape, device):
+        """This is borrowed from `modeling_utils.py` with the support of
+        multi-layer attention masks.
+        The second dim is expected to be number of layers.
+        See `MMAttentionMaskProcessor`.
+        Makes broadcastable attention and causal masks so that future
+        and masked tokens are ignored.
+
+        Arguments:
+            attention_mask (:obj:`torch.Tensor`):
+                Mask with ones indicating tokens to attend to,
+                zeros for tokens to ignore.
+            input_shape (:obj:`Tuple[int]`):
+                The shape of the input to the model.
+            device: (:obj:`torch.device`):
+                The device of the input to the model.
+
+        Returns:
+            :obj:`torch.Tensor` The extended attention mask, \
+                with the same dtype as :obj:`attention_mask.dtype`.
+        """
+        # We can provide a self-attention mask of dimensions
+        # [batch_size, from_seq_length, to_seq_length]
+        # ourselves in which case we just need to make it broadcastable
+        # to all heads.
+        if attention_mask.dim() == 4:
+            extended_attention_mask = attention_mask[:, :, None, :, :]
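+            # a 4-D input mask is [batch, num_layers, from_seq, to_seq];
+            # inserting a singleton head dim gives the 5-D mask that
+            # MultiLayerAttentionMaskBertEncoder indexes per layer below.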
+            extended_attention_mask = extended_attention_mask.to(
+                dtype=self.dtype
+            )  # fp16 compatibility
+            extended_attention_mask = (1.0 - extended_attention_mask) \
+                * -10000.0
+            return extended_attention_mask
+        else:
+            return super().get_extended_attention_mask(
+                attention_mask, input_shape, device
+            )
+
+
+class MultiLayerAttentionMaskBertEncoder(BertEncoder):
+    """Extends BertEncoder with support for a separate attention mask
+    per layer."""
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        output_attentions=False,
+        output_hidden_states=False,
+        return_dict=False,
+    ):
+        all_hidden_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None
+        for i, layer_module in enumerate(self.layer):
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+            layer_head_mask = head_mask[i] if head_mask is not None else None
+
+            layer_attention_mask = (
+                attention_mask[:, i, :, :, :]
+                if attention_mask.dim() == 5
+                else attention_mask
+            )
+
+            if getattr(self.config, "gradient_checkpointing", False):
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs, output_attentions)
+
+                    return custom_forward
+
+                layer_outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(layer_module),
+                    hidden_states,
+                    layer_attention_mask,
+                    layer_head_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                )
+            else:
+                layer_outputs = layer_module(
+                    hidden_states,
+                    layer_attention_mask,
+                    layer_head_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    output_attentions,
+                )
+            hidden_states = layer_outputs[0]
+            if output_attentions:
+                all_attentions = all_attentions + (layer_outputs[1],)
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        return tuple(
+            v
+            for v in [hidden_states, all_hidden_states, all_attentions]
+            if v is not None
+        )

+ 10 - 0
examples/MMPT/mmpt/modules/__init__.py

@@ -0,0 +1,10 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from .mm import *
+
+try:
+    from .expmm import *
+except ImportError:
+    pass

+ 145 - 0
examples/MMPT/mmpt/modules/mm.py

@@ -0,0 +1,145 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Copyright (c) Facebook, Inc. All Rights Reserved
+
+
+import torch
+
+from torch import nn
+
+try:
+    from transformers.modeling_bert import (
+        BertEmbeddings,
+        ACT2FN,
+    )
+except ImportError:
+    pass
+
+
+class VideoTokenMLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        input_dim = config.input_dim if hasattr(config, "input_dim") else 512
+        self.linear1 = nn.Linear(input_dim, config.hidden_size)
+        self.LayerNorm = nn.LayerNorm(config.hidden_size)
+        self.activation = ACT2FN[config.hidden_act]
+        self.linear2 = nn.Linear(config.hidden_size, config.hidden_size)
+
+    def forward(self, hidden_states):
+        hidden_states = self.linear1(hidden_states)
+        hidden_states = self.activation(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states)
+        hidden_states = self.linear2(hidden_states)
+        return hidden_states
+
+
+class MMBertEmbeddings(BertEmbeddings):
+    def __init__(self, config):
+        super().__init__(config)
+        self.max_video_len = config.max_video_len
+        if hasattr(config, "use_seg_emb") and config.use_seg_emb:
+            """the original VLM paper uses seg_embeddings for the temporal space.
+            although not used, it changes the randomness of initialization,
+            so we keep it for reproducibility.
+            """
+            self.seg_embeddings = nn.Embedding(256, config.hidden_size)
+
+    def forward(
+        self,
+        input_ids,
+        input_video_embeds,
+        token_type_ids=None,
+        position_ids=None,
+        inputs_embeds=None,
+    ):
+        input_tensor = input_ids if input_ids is not None else inputs_embeds
+        if input_video_embeds is not None:
+            input_shape = (
+                input_tensor.size(0),
+                input_tensor.size(1) + input_video_embeds.size(1),
+            )
+        else:
+            input_shape = (input_tensor.size(0), input_tensor.size(1))
+
+        if position_ids is None:
+            """
+            Automatically compute position ids, skipping unused video slots
+            for the text-only case.
+            use cases:
+            (1) action localization and segmentation:
+                a length-1 dummy video token is fed in, so the text part
+                must skip input_video_embeds.size(1) positions to get the
+                right position_ids for the video [SEP] and the remaining
+                text tokens.
+            (2) MMFusionShare with two forward passes:
+                in `forward_text`, input_video_embeds is None, so the
+                video [SEP] position must be skipped.
+
+            # video_len + 1: [CLS] + video_embed
+            # self.max_video_len + 1: [SEP] for video.
+            # self.max_video_len + 2: first text token.
+            # self.max_video_len + input_ids.size(1): rest for text.
+            """
+            if input_video_embeds is not None:
+                video_len = input_video_embeds.size(1)
+                starting_offset = self.max_video_len + 1  # video [SEP]
+                ending_offset = self.max_video_len + input_ids.size(1)
+            else:
+                video_len = 0
+                starting_offset = self.max_video_len + 2  # first text token.
+                ending_offset = self.max_video_len + input_ids.size(1) + 1
+            position_ids = torch.cat([
+                self.position_ids[:, :video_len + 1],
+                self.position_ids[:, starting_offset:ending_offset]
+                ], dim=1)
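+            # e.g. with max_video_len = 32, video_len = 4 and 10 text ids,
+            # this yields positions [0..4] ++ [33..41]: 14 ids in total, with
+            # the video [SEP] pinned at position 33 whatever the clip length.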
+
+        if token_type_ids is None:
+            token_type_ids = torch.zeros(
+                input_shape, dtype=torch.long, device=self.position_ids.device
+            )
+
+        """
+        the format of input_ids is [CLS] [SEP] caption [SEP] padding.
+        the goal is to build [CLS] video tokens [SEP] caption [SEP] .
+        """
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+        if input_video_embeds is not None:
+            inputs_mm_embeds = torch.cat([
+                inputs_embeds[:, :1], input_video_embeds, inputs_embeds[:, 1:]
+            ], dim=1)
+        else:
+            # text only for `MMFusionShare`.
+            inputs_mm_embeds = inputs_embeds
+
+        position_embeddings = self.position_embeddings(position_ids)
+        token_type_embeddings = self.token_type_embeddings(token_type_ids)
+        embeddings = inputs_mm_embeds + position_embeddings
+        embeddings += token_type_embeddings
+
+        embeddings = self.LayerNorm(embeddings)
+        embeddings = self.dropout(embeddings)
+        return embeddings
+
+
+class AlignHead(nn.Module):
+    """this will load pre-trained weights for NSP, which is desirable."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.seq_relationship = nn.Linear(config.hidden_size, 2)
+
+    def forward(self, dropout_pooled_output):
+        logits = self.seq_relationship(dropout_pooled_output)
+        return logits

+ 429 - 0
examples/MMPT/mmpt/modules/retri.py

@@ -0,0 +1,429 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import os
+import numpy as np
+import pickle
+import time
+
+try:
+    import faiss
+except ImportError:
+    pass
+
+from collections import defaultdict
+
+from ..utils import get_local_rank, print_on_rank0
+
+
+class VectorRetriever(object):
+    """
+    How2 Video Retriever.
+    Reference usage of FAISS:
+    https://github.com/fairinternal/fairseq-py/blob/paraphrase_pretraining/fairseq/data/multilingual_faiss_dataset.py
+    """
+
+    def __init__(self, hidden_size, cent, db_type, examples_per_cent_to_train):
+        if db_type == "flatl2":
+            quantizer = faiss.IndexFlatL2(hidden_size)  # the other index
+            self.db = faiss.IndexIVFFlat(
+                quantizer, hidden_size, cent, faiss.METRIC_L2)
+        elif db_type == "pq":
+            self.db = faiss.index_factory(
+                    hidden_size, f"IVF{cent}_HNSW32,PQ32"
+            )
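+            # IVF index with `cent` inverted lists, an HNSW32 coarse
+            # quantizer and 32-byte product-quantized codes.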
+        else:
+            raise ValueError("unknown type of db", db_type)
+        self.train_thres = cent * examples_per_cent_to_train
+        self.train_cache = []
+        self.train_len = 0
+        self.videoid_to_vectoridx = {}
+        self.vectoridx_to_videoid = None
+        self.make_direct_maps_done = False
+
+    def make_direct_maps(self):
+        faiss.downcast_index(self.db).make_direct_map()
+
+    def __len__(self):
+        return self.db.ntotal
+
+    def save(self, out_dir):
+        faiss.write_index(
+            self.db,
+            os.path.join(out_dir, "faiss_idx")
+        )
+        with open(
+                os.path.join(
+                    out_dir, "videoid_to_vectoridx.pkl"),
+                "wb") as fw:
+            pickle.dump(
+                self.videoid_to_vectoridx, fw,
+                protocol=pickle.HIGHEST_PROTOCOL
+            )
+
+    def load(self, out_dir):
+        fn = os.path.join(out_dir, "faiss_idx")
+        self.db = faiss.read_index(fn)
+        with open(
+                os.path.join(out_dir, "videoid_to_vectoridx.pkl"), "rb") as fr:
+            self.videoid_to_vectoridx = pickle.load(fr)
+
+    def add(self, hidden_states, video_ids, last=False):
+        assert len(hidden_states) == len(video_ids), "{}, {}".format(
+            str(len(hidden_states)), str(len(video_ids)))
+        assert len(hidden_states.shape) == 2
+        assert hidden_states.dtype == np.float32
+
+        valid_idx = []
+        for idx, video_id in enumerate(video_ids):
+            if video_id not in self.videoid_to_vectoridx:
+                valid_idx.append(idx)
+                self.videoid_to_vectoridx[video_id] = \
+                    len(self.videoid_to_vectoridx)
+
+        hidden_states = hidden_states[valid_idx]
+        if not self.db.is_trained:
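+            # buffer incoming vectors until `train_thres` samples have been
+            # collected, then train the IVF index once and add everything.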
+            self.train_cache.append(hidden_states)
+            self.train_len += hidden_states.shape[0]
+            if self.train_len < self.train_thres:
+                return
+            self.finalize_training()
+        else:
+            self.db.add(hidden_states)
+
+    def finalize_training(self):
+        hidden_states = np.concatenate(self.train_cache, axis=0)
+        del self.train_cache
+        local_rank = get_local_rank()
+        if local_rank == 0:
+            start = time.time()
+            print("training db on", self.train_thres, "/", self.train_len)
+        self.db.train(hidden_states[:self.train_thres])
+        if local_rank == 0:
+            print("training db for", time.time() - start)
+        self.db.add(hidden_states)
+
+    def search(
+        self,
+        query_hidden_states,
+        orig_dist,
+    ):
+        if len(self.videoid_to_vectoridx) != self.db.ntotal:
+            raise ValueError(
+                "cannot search: size mismatch in-between index and db",
+                len(self.videoid_to_vectoridx),
+                self.db.ntotal
+            )
+
+        if self.vectoridx_to_videoid is None:
+            self.vectoridx_to_videoid = {
+                self.videoid_to_vectoridx[videoid]: videoid
+                for videoid in self.videoid_to_vectoridx
+            }
+            assert len(self.vectoridx_to_videoid) \
+                == len(self.videoid_to_vectoridx)
+
+        # MultilingualFaissDataset uses the following; not sure the purpose.
+        # faiss.ParameterSpace().set_index_parameter(self.db, "nprobe", 10)
+        queried_dist, index = self.db.search(query_hidden_states, 1)
+        queried_dist, index = queried_dist[:, 0], index[:, 0]
+
+        outputs = np.array(
+            [self.vectoridx_to_videoid[_index]
+                if _index != -1 else (-1, -1, -1) for _index in index],
+            dtype=np.int32)
+        outputs[queried_dist <= orig_dist] = -1
+        return outputs
+
+    def search_by_video_ids(
+        self,
+        video_ids,
+        retri_factor
+    ):
+        if len(self.videoid_to_vectoridx) != self.db.ntotal:
+            raise ValueError(
+                len(self.videoid_to_vectoridx),
+                self.db.ntotal
+            )
+
+        if not self.make_direct_maps_done:
+            self.make_direct_maps()
+
+        if self.vectoridx_to_videoid is None:
+            self.vectoridx_to_videoid = {
+                self.videoid_to_vectoridx[videoid]: videoid
+                for videoid in self.videoid_to_vectoridx
+            }
+            assert len(self.vectoridx_to_videoid) \
+                == len(self.videoid_to_vectoridx)
+
+        query_hidden_states = []
+        vector_ids = []
+        for video_id in video_ids:
+            vector_id = self.videoid_to_vectoridx[video_id]
+            vector_ids.append(vector_id)
+            query_hidden_state = self.db.reconstruct(vector_id)
+            query_hidden_states.append(query_hidden_state)
+        query_hidden_states = np.stack(query_hidden_states)
+
+        # MultilingualFaissDataset uses the following; not sure the reason.
+        # faiss.ParameterSpace().set_index_parameter(self.db, "nprobe", 10)
+        _, index = self.db.search(query_hidden_states, retri_factor)
+        outputs = []
+        for sample_idx, sample in enumerate(index):
+            # the first video_id is always the video itself.
+            cands = [video_ids[sample_idx]]
+            for vector_idx in sample:
+                if vector_idx >= 0 \
+                        and vector_ids[sample_idx] != vector_idx:
+                    cands.append(
+                        self.vectoridx_to_videoid[vector_idx]
+                    )
+            outputs.append(cands)
+        return outputs
+
+
+class VectorRetrieverDM(VectorRetriever):
+    """
+    with direct map.
+    How2 Video Retriever.
+    Reference usage of FAISS:
+    https://github.com/fairinternal/fairseq-py/blob/paraphrase_pretraining/fairseq/data/multilingual_faiss_dataset.py
+    """
+
+    def __init__(
+        self,
+        hidden_size,
+        cent,
+        db_type,
+        examples_per_cent_to_train
+    ):
+        super().__init__(
+            hidden_size, cent, db_type, examples_per_cent_to_train)
+        self.make_direct_maps_done = False
+
+    def make_direct_maps(self):
+        faiss.downcast_index(self.db).make_direct_map()
+        self.make_direct_maps_done = True
+
+    def search(
+        self,
+        query_hidden_states,
+        orig_dist,
+    ):
+        if len(self.videoid_to_vectoridx) != self.db.ntotal:
+            raise ValueError(
+                len(self.videoid_to_vectoridx),
+                self.db.ntotal
+            )
+
+        if not self.make_direct_maps_done:
+            self.make_direct_maps()
+        if self.vectoridx_to_videoid is None:
+            self.vectoridx_to_videoid = {
+                self.videoid_to_vectoridx[videoid]: videoid
+                for videoid in self.videoid_to_vectoridx
+            }
+            assert len(self.vectoridx_to_videoid) \
+                == len(self.videoid_to_vectoridx)
+
+        # MultilingualFaissDataset uses the following; not sure the reason.
+        # faiss.ParameterSpace().set_index_parameter(self.db, "nprobe", 10)
+        queried_dist, index = self.db.search(query_hidden_states, 1)
+        outputs = []
+        for sample_idx, sample in enumerate(index):
+            # and queried_dist[sample_idx] < thres \
+            if sample >= 0 \
+                    and queried_dist[sample_idx] < orig_dist[sample_idx]:
+                outputs.append(self.vectoridx_to_videoid[sample])
+            else:
+                outputs.append(None)
+        return outputs
+
+    def search_by_video_ids(
+        self,
+        video_ids,
+        retri_factor=8
+    ):
+        if len(self.videoid_to_vectoridx) != self.db.ntotal:
+            raise ValueError(
+                len(self.videoid_to_vectoridx),
+                self.db.ntotal
+            )
+
+        if not self.make_direct_maps_done:
+            self.make_direct_maps()
+        if self.vectoridx_to_videoid is None:
+            self.vectoridx_to_videoid = {
+                self.videoid_to_vectoridx[videoid]: videoid
+                for videoid in self.videoid_to_vectoridx
+            }
+            assert len(self.vectoridx_to_videoid) \
+                == len(self.videoid_to_vectoridx)
+
+        query_hidden_states = []
+        vector_ids = []
+        for video_id in video_ids:
+            vector_id = self.videoid_to_vectoridx[video_id]
+            vector_ids.append(vector_id)
+            query_hidden_state = self.db.reconstruct(vector_id)
+            query_hidden_states.append(query_hidden_state)
+        query_hidden_states = np.stack(query_hidden_states)
+
+        # MultilingualFaissDataset uses the following; not sure the reason.
+        # faiss.ParameterSpace().set_index_parameter(self.db, "nprobe", 10)
+        _, index = self.db.search(query_hidden_states, retri_factor)
+        outputs = []
+        for sample_idx, sample in enumerate(index):
+            # the first video_id is always the video itself.
+            cands = [video_ids[sample_idx]]
+            for vector_idx in sample:
+                if vector_idx >= 0 \
+                        and vector_ids[sample_idx] != vector_idx:
+                    cands.append(
+                        self.vectoridx_to_videoid[vector_idx]
+                    )
+            outputs.append(cands)
+        return outputs
+
+
+class MMVectorRetriever(VectorRetrieverDM):
+    """
+    multimodal vector retriever:
+    text retrieves video or video retrieves text.
+    """
+
+    def __init__(self, hidden_size, cent, db_type, examples_per_cent_to_train):
+        super().__init__(
+            hidden_size, cent, db_type, examples_per_cent_to_train)
+        video_db = self.db
+        super().__init__(
+            hidden_size, cent, db_type, examples_per_cent_to_train)
+        text_db = self.db
+        self.db = {"video": video_db, "text": text_db}
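+        # the parent constructor runs twice so that video and text each get
+        # their own FAISS index; both are then kept under a single dict.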
+        self.video_to_videoid = defaultdict(list)
+
+    def __len__(self):
+        assert self.db["video"].ntotal == self.db["text"].ntotal
+        return self.db["video"].ntotal
+
+    def make_direct_maps(self):
+        faiss.downcast_index(self.db["video"]).make_direct_map()
+        faiss.downcast_index(self.db["text"]).make_direct_map()
+
+    def save(self, out_dir):
+        faiss.write_index(
+            self.db["video"],
+            os.path.join(out_dir, "video_faiss_idx")
+        )
+        faiss.write_index(
+            self.db["text"],
+            os.path.join(out_dir, "text_faiss_idx")
+        )
+
+        with open(
+                os.path.join(
+                    out_dir, "videoid_to_vectoridx.pkl"),
+                "wb") as fw:
+            pickle.dump(
+                self.videoid_to_vectoridx, fw,
+                protocol=pickle.HIGHEST_PROTOCOL
+            )
+
+    def load(self, out_dir):
+        fn = os.path.join(out_dir, "video_faiss_idx")
+        video_db = faiss.read_index(fn)
+        fn = os.path.join(out_dir, "text_faiss_idx")
+        text_db = faiss.read_index(fn)
+        self.db = {"video": video_db, "text": text_db}
+        with open(
+                os.path.join(out_dir, "videoid_to_vectoridx.pkl"), "rb") as fr:
+            self.videoid_to_vectoridx = pickle.load(fr)
+        self.video_to_videoid = defaultdict(list)
+
+    def add(self, hidden_states, video_ids):
+        """hidden_states is a pair `(video, text)`"""
+        assert len(hidden_states) == len(video_ids), "{}, {}".format(
+            str(len(hidden_states)), str(len(video_ids)))
+        assert len(hidden_states.shape) == 3
+        assert len(self.video_to_videoid) == 0
+
+        valid_idx = []
+        for idx, video_id in enumerate(video_ids):
+            if video_id not in self.videoid_to_vectoridx:
+                valid_idx.append(idx)
+                self.videoid_to_vectoridx[video_id] = \
+                    len(self.videoid_to_vectoridx)
+
+        batch_size = hidden_states.shape[0]
+        hidden_states = hidden_states[valid_idx]
+
+        hidden_states = np.transpose(hidden_states, (1, 0, 2)).copy()
+        if not self.db["video"].is_trained:
+            self.train_cache.append(hidden_states)
+            train_len = batch_size * len(self.train_cache)
+            if train_len < self.train_thres:
+                return
+
+            hidden_states = np.concatenate(self.train_cache, axis=1)
+            del self.train_cache
+            self.db["video"].train(hidden_states[0, :self.train_thres])
+            self.db["text"].train(hidden_states[1, :self.train_thres])
+        self.db["video"].add(hidden_states[0])
+        self.db["text"].add(hidden_states[1])
+
+    def get_clips_by_video_id(self, video_id):
+        if not self.video_to_videoid:
+            # keys are expected to be (video_id, video_clip, text_clip)
+            # tuples; use a fresh loop variable so that the `video_id`
+            # argument is not shadowed.
+            for vid, video_clip, text_clip in self.videoid_to_vectoridx:
+                self.video_to_videoid[vid].append(
+                    (vid, video_clip, text_clip))
+        return self.video_to_videoid[video_id]
+
+    def search(
+        self,
+        video_ids,
+        target_modality,
+        retri_factor=8
+    ):
+        if len(self.videoid_to_vectoridx) != len(self):
+            raise ValueError(
+                "{} != {}".format(
+                    len(self.videoid_to_vectoridx), len(self))
+            )
+
+        if not self.make_direct_maps_done:
+            self.make_direct_maps()
+        if self.vectoridx_to_videoid is None:
+            self.vectoridx_to_videoid = {
+                self.videoid_to_vectoridx[videoid]: videoid
+                for videoid in self.videoid_to_vectoridx
+            }
+            assert len(self.vectoridx_to_videoid) \
+                == len(self.videoid_to_vectoridx)
+
+        src_modality = "text" if target_modality == "video" else "video"
+
+        query_hidden_states = []
+        vector_ids = []
+        for video_id in video_ids:
+            vector_id = self.videoid_to_vectoridx[video_id]
+            vector_ids.append(vector_id)
+            query_hidden_state = self.db[src_modality].reconstruct(vector_id)
+            query_hidden_states.append(query_hidden_state)
+        query_hidden_states = np.stack(query_hidden_states)
+
+        # MultilingualFaissDataset uses the following; not sure the reason.
+        # faiss.ParameterSpace().set_index_parameter(self.db, "nprobe", 10)
+        _, index = self.db[target_modality].search(
+            query_hidden_states, retri_factor)
+        outputs = []
+        for sample_idx, sample in enumerate(index):
+            cands = []
+            for vector_idx in sample:
+                if vector_idx >= 0:
+                    cands.append(
+                        self.vectoridx_to_videoid[vector_idx]
+                    )
+            outputs.append(cands)
+        return outputs
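+
+
+# Illustrative usage sketch (hedged): instances of this retriever are normally
+# created via `VectorPool.build_retriver` in vectorpool.py rather than directly.
+#   retriever.add(pair_states, video_ids)         # pair_states: (batch, 2, hidden)
+#   cands = retriever.search(video_ids, "video")  # text queries retrieve videos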

+ 246 - 0
examples/MMPT/mmpt/modules/vectorpool.py

@@ -0,0 +1,246 @@
+# Copyright (c) Facebook, Inc. All Rights Reserved
+
+import torch
+import os
+import numpy as np
+import pickle
+
+from . import retri
+from ..utils import get_local_rank
+
+
+class VectorPool(object):
+    """
+    Base class of retrieval space.
+    """
+
+    def __init__(self, config):
+        from transformers import AutoConfig
+        self.hidden_size = AutoConfig.from_pretrained(
+            config.dataset.bert_name).hidden_size
+        self.retriever_cls = getattr(retri, config.retriever_cls)
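+        # `config.dataset.bert_name` is a HuggingFace model name (it determines
+        # the hidden size) and `config.retriever_cls` names a class in `retri`.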
+
+    def __call__(self, sample, **kwargs):
+        raise NotImplementedError
+
+    def build_retriver(
+        self,
+        retriever_cls=None,
+        hidden_size=None,
+        centroids=512,
+        db_type="flatl2",
+        examples_per_cent_to_train=48
+    ):
+        """Merge results from multiple GPUs and return a retriever."""
+        self.retriver = retriever_cls(
+            hidden_size, centroids, db_type, examples_per_cent_to_train)
+        return self.retriver
+
+    def __repr__(self):
+        if hasattr(self, "retriver"):
+            retriver_name = str(len(self.retriver))
+        else:
+            retriver_name = "no retriver field yet"
+        return self.__class__.__name__ \
+            + "(" + retriver_name + ")"
+
+
+class VideoVectorPool(VectorPool):
+    """
+    Average the clips of a video to form the video representation.
+    """
+    def __init__(self, config):
+        super().__init__(config)
+        self.build_retriver(self.retriever_cls, self.hidden_size)
+
+    def __call__(self, sample, subsampling, **kwargs):
+        hidden_states = (
+            sample["pooled_video"] + sample["pooled_text"]) / 2.
+        hidden_states = hidden_states.view(
+            -1, subsampling,
+            hidden_states.size(-1))
+        hidden_states = torch.mean(hidden_states, dim=1)
+        hidden_states = hidden_states.cpu().detach().numpy()
+        video_ids = []
+        for offset_idx, video_id in enumerate(sample["video_id"]):
+            if isinstance(video_id, tuple) and len(video_id) == 3:
+                # a sharded video_id.
+                video_id = video_id[0]
+            video_ids.append(video_id)
+        assert len(video_ids) == len(hidden_states)
+        self.retriver.add(
+            hidden_states.astype("float32"),
+            video_ids
+        )
+
+
+class DistributedVectorPool(VectorPool):
+    """
+    Supports syncing vectors across multiple GPUs/nodes.
+    """
+    def __init__(self, config):
+        super().__init__(config)
+        self.out_dir = os.path.join(
+            config.fairseq.checkpoint.save_dir,
+            "retri")
+        os.makedirs(self.out_dir, exist_ok=True)
+        self.hidden_states = []
+        self.video_ids = []
+
+    def build_retriver(
+        self,
+        retriever_cls=None,
+        hidden_size=None,
+        centroids=4096,
+        db_type="flatl2",
+        examples_per_cent_to_train=48
+    ):
+        """Merge results from multiple GPUs and return a retriever."""
+        if retriever_cls is None:
+            retriever_cls = self.retriever_cls
+        if hidden_size is None:
+            hidden_size = self.hidden_size
+        if torch.distributed.is_initialized():
+            self.save()
+            # sync saving.
+            torch.distributed.barrier()
+            world_size = torch.distributed.get_world_size()
+        else:
+            world_size = 1
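+        # each rank now rebuilds a full retriever locally by loading the
+        # per-rank dumps stored in `self.out_dir`.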
+        self.retriver = retriever_cls(
+            hidden_size, centroids, db_type, examples_per_cent_to_train)
+        # each gpu process has its own retriever.
+        for local_rank in range(world_size):
+            if get_local_rank() == 0:
+                print("load local_rank", local_rank)
+            hidden_states, video_ids = self.load(local_rank)
+            hidden_states = hidden_states.astype("float32")
+            self.retriver.add(hidden_states, video_ids)
+        return self.retriver
+
+    def load(self, local_rank):
+        hidden_states = np.load(
+            os.path.join(
+                self.out_dir,
+                "hidden_state" + str(local_rank) + ".npy"
+            )
+        )
+
+        with open(
+            os.path.join(
+                self.out_dir, "video_id" + str(local_rank) + ".pkl"),
+                "rb") as fr:
+            video_ids = pickle.load(fr)
+        return hidden_states, video_ids
+
+    def save(self):
+        hidden_states = np.vstack(self.hidden_states)
+        assert len(hidden_states) == len(self.video_ids), "{}, {}".format(
+            len(hidden_states),
+            len(self.video_ids)
+        )
+        local_rank = torch.distributed.get_rank() \
+            if torch.distributed.is_initialized() else 0
+
+        np.save(
+            os.path.join(
+                self.out_dir,
+                "hidden_state" + str(local_rank) + ".npy"),
+            hidden_states)
+
+        with open(
+            os.path.join(
+                self.out_dir,
+                "video_id" + str(local_rank) + ".pkl"),
+                "wb") as fw:
+            pickle.dump(
+                self.video_ids,
+                fw,
+                protocol=pickle.HIGHEST_PROTOCOL
+            )
+
+
+class DistributedVideoVectorPool(DistributedVectorPool):
+    """
+    Average the clips of a video to form the video representation.
+    """
+    def __call__(self, sample, subsampling, **kwargs):
+        hidden_states = (
+            sample["pooled_video"] + sample["pooled_text"]) / 2.
+        hidden_states = hidden_states.view(
+            -1, subsampling,
+            hidden_states.size(-1))
+        hidden_states = torch.mean(hidden_states, dim=1)
+        hidden_states = hidden_states.cpu().detach().numpy()
+        video_ids = []
+        for offset_idx, video_id in enumerate(sample["video_id"]):
+            if isinstance(video_id, tuple) and len(video_id) == 3:
+                # a sharded video_id.
+                video_id = video_id[0]
+            video_ids.append(video_id)
+        assert len(video_ids) == len(hidden_states)
+        self.hidden_states.append(hidden_states)
+        self.video_ids.extend(video_ids)
+
+
+# ------------ the following are deprecated --------------
+
+class TextClipVectorPool(VectorPool):
+    def __init__(self, config):
+        from transformers import AutoConfig
+        hidden_size = AutoConfig.from_pretrained(
+            config.dataset.bert_name).hidden_size
+        retriever_cls = getattr(retri, config.retriever_cls)
+        self.build_retriver(retriever_cls, hidden_size)
+
+    def __call__(self, sample, **kwargs):
+        clip_meta = sample["clip_meta"].cpu()
+        assert torch.all(torch.le(clip_meta[:, 4], clip_meta[:, 5]))
+        text_meta = [tuple(item.tolist()) for item in clip_meta[:, 3:]]
+
+        if hasattr(self, "retriver"):
+            # build_retriver is called.
+            self.retriver.add(
+                sample["pooled_text"].cpu().numpy().astype("float32"),
+                text_meta
+            )
+        else:
+            raise NotImplementedError
+
+
+class MMClipVectorPool(VectorPool):
+    """
+    Multimodal Clip-level vector pool.
+    """
+    def __init__(self, out_dir):
+        """`hidden_states` stores `(video, text)` pairs;
+        `video_ids` stores `(video_id, start, end)` tuples."""
+        super().__init__(out_dir)
+
+    def __call__(self, sample, **kwargs):
+        pooled_video = sample["pooled_video"].cpu().unsqueeze(1).numpy()
+        pooled_text = sample["pooled_text"].cpu().unsqueeze(1).numpy()
+
+        self.hidden_states.append(
+            np.concatenate([pooled_video, pooled_text], axis=1)
+        )
+
+        video_starts = sample["video_start"].cpu()
+        video_ends = sample["video_end"].cpu()
+        assert torch.all(torch.le(video_starts, video_ends))
+
+        text_starts = sample["text_start"].cpu()
+        text_ends = sample["text_end"].cpu()
+        assert torch.all(torch.le(text_starts, text_ends))
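+        # each video contributes `subsample_size` clips, so its id is repeated once per clip.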
+        subsample_size = sample["pooled_video"].size(0) // len(sample["video_id"])
+        video_ids = [video_id for video_id in sample["video_id"]
+                    for _ in range(subsample_size)
+        ]
+        for video_id, video_start, video_end, text_start, text_end in zip(
+                video_ids, video_starts, video_ends, text_starts, text_ends):
+            self.video_ids.append((
+                video_id,
+                (int(video_start), int(video_end)),
+                (int(text_start), int(text_end))
+            ))

+ 23 - 0
examples/MMPT/mmpt/processors/__init__.py

@@ -0,0 +1,23 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from .processor import *
+
+from .how2processor import *
+from .how2retriprocessor import *
+
+from .dsprocessor import *
+
+try:
+    from .rawvideoprocessor import *
+    from .codecprocessor import *
+    from .webvidprocessor import *
+    from .expprocessor import *
+    from .exphow2processor import *
+    from .exphow2retriprocessor import *
+    from .expcodecprocessor import *
+    from .expfeatureencoder import *
+    from .expdsprocessor import *
+except ImportError:
+    pass

+ 242 - 0
examples/MMPT/mmpt/processors/dedupprocessor.py

@@ -0,0 +1,242 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import random
+import json
+import pickle
+from tqdm import tqdm
+import os
+import numpy as np
+
+
+class CaptionDedupProcessor(object):
+    """remove overlapping of caption sentences(clip).
+    Some statistics:
+    caption:
+    {'t_clip_len': 246.6448431320854,
+    'video_len': 281.09174795676245,
+    'clip_tps': 0.8841283727427481,
+    'video_tps': 0.7821156477732097,
+    'min_clip_len': 0.0,
+    'max_clip_len': 398.3,
+    'mean_clip_len': 3.196580003006861,
+    'num_clip': 77.15897706301081}
+
+    raw_caption:
+    {'t_clip_len': 238.95908778424115,
+    'video_len': 267.5914859862507,
+    'clip_tps': 2.4941363624267963,
+    'video_tps': 2.258989769647173,
+    'min_clip_len': 0.0,
+    'max_clip_len': 398.3,
+    'mean_clip_len': 3.0537954186814265,
+    'num_clip': 78.24986779481756}
+    """
+
+    def __init__(self, pkl_file):
+        with open(pkl_file, "rb") as fd:
+            self.data = pickle.load(fd)
+        self.stat = {
+            "t_clip_len": [],
+            "video_len": [],
+            "clip_tps": [],
+            "video_tps": [],
+            "clip_len": [],
+        }
+
+    def __call__(self):
+        for idx, video_id in enumerate(tqdm(self.data)):
+            caption = json.loads(self.data[video_id])
+            caption = self._dedup(caption)
+            if idx < 4096:  # for the first 4096 examples, compute the statistics.
+                self.save_stat(video_id, caption)
+            self.data[video_id] = json.dumps(caption)
+        self.print_stat()
+
+    def single(self, video_id):
+        caption = json.loads(self.data[video_id])
+        for clip_idx, (start, end, text) in enumerate(
+            zip(caption["start"], caption["end"], caption["text"])
+        ):
+            print(start, end, text)
+        print("@" * 100)
+        caption = self._dedup(caption)
+        for clip_idx, (start, end, text) in enumerate(
+            zip(caption["start"], caption["end"], caption["text"])
+        ):
+            print(start, end, text)
+        print("#" * 100)
+        self.save_stat(video_id, caption)
+        self.print_stat()
+
+    def finalize(self, tgt_fn):
+        with open(tgt_fn, "wb") as fw:
+            pickle.dump(self.data, fw, pickle.HIGHEST_PROTOCOL)
+
+    def save_stat(self, video_id, caption):
+        video_fn = os.path.join(
+            "data/feat/feat_how2_s3d", video_id + ".npy"
+        )
+        if os.path.isfile(video_fn):
+            with open(video_fn, "rb", 1) as fr:  # 24 is the buffer size. buffered
+                version = np.lib.format.read_magic(fr)
+                shape, fortran, dtype = np.lib.format._read_array_header(fr, version)
+                video_len = shape[0]
+
+            t_clip_len = 0.0
+            t_tokens = 0
+            for idx, (start, end, text) in enumerate(
+                zip(caption["start"], caption["end"], caption["text"])
+            ):
+                clip_len = (
+                    (end - max(caption["end"][idx - 1], start))
+                    if idx > 0
+                    else end - start
+                )
+                t_clip_len += clip_len
+                t_tokens += len(text.split(" "))
+                self.stat["clip_len"].append(clip_len)
+            self.stat["t_clip_len"].append(t_clip_len)
+            self.stat["video_len"].append(video_len)
+            self.stat["clip_tps"].append(t_tokens / t_clip_len)
+            self.stat["video_tps"].append(t_tokens / video_len)
+
+    def print_stat(self):
+        result = {
+            "t_clip_len": np.mean(self.stat["t_clip_len"]),
+            "video_len": np.mean(self.stat["video_len"]),
+            "clip_tps": np.mean(self.stat["clip_tps"]),
+            "video_tps": np.mean(self.stat["video_tps"]),
+            "min_clip_len": min(self.stat["clip_len"]),
+            "max_clip_len": max(self.stat["clip_len"]),
+            "mean_clip_len": np.mean(self.stat["clip_len"]),
+            "num_clip": len(self.stat["clip_len"]) / len(self.stat["video_tps"]),
+        }
+        print(result)
+
+    def _dedup(self, caption):
+        def random_merge(end_idx, start, end, text, starts, ends, texts):
+            if random.random() > 0.5:
+                # print(clip_idx, "[PARTIAL INTO PREV]", end_idx)
+                # overlapped part goes to the end of previous.
+                ends[-1] = max(ends[-1], start)  # ?
+                rest_text = text[end_idx:].strip()
+                if rest_text:
+                    starts.append(max(ends[-1], start))
+                    ends.append(max(end, starts[-1]))
+                    texts.append(rest_text)
+            else:  # goes to the beginning of the current.
+                # strip the previous.
+                left_text = texts[-1][:-end_idx].strip()
+                if left_text:
+                    # print(clip_idx, "[PREV PARTIAL INTO CUR]", end_idx)
+                    ends[-1] = min(ends[-1], start)
+                    texts[-1] = left_text
+                else:
+                    # print(clip_idx, "[PREV LEFT NOTHING ALL INTO CUR]", end_idx)
+                    starts.pop(-1)
+                    ends.pop(-1)
+                    texts.pop(-1)
+                starts.append(start)
+                ends.append(end)
+                texts.append(text)
+
+        starts, ends, texts = [], [], []
+        for clip_idx, (start, end, text) in enumerate(
+            zip(caption["start"], caption["end"], caption["text"])
+        ):
+            if not isinstance(text, str):
+                continue
+            text = text.replace("\n", " ").strip()
+            if len(text) == 0:
+                continue
+            starts.append(start)
+            ends.append(end)
+            texts.append(text)
+            break
+
+        for clip_idx, (start, end, text) in enumerate(
+            zip(
+                caption["start"][clip_idx + 1:],
+                caption["end"][clip_idx + 1:],
+                caption["text"][clip_idx + 1:],
+            )
+        ):
+            if not isinstance(text, str):
+                continue
+            text = text.replace("\n", " ").strip()
+            if len(text) == 0:
+                continue
+
+            # print(clip_idx, texts[-5:])
+            # print(clip_idx, start, end, text)
+            if texts[-1].endswith(text):  # subset of prev caption -> merge
+                # print(clip_idx, "[MERGE INTO PREV]")
+                ends[-1] = max(ends[-1], end)
+            elif text.startswith(texts[-1]):  # superset of prev caption -> merge
+                # print(clip_idx, "[PREV MERGE INTO CUR]")
+                texts[-1] = text
+                starts[-1] = min(starts[-1], start)
+                ends[-1] = max(ends[-1], end)
+            else:  # overlapping or non-overlapping.
+                for end_idx in range(1, len(text) + 1):
+                    if texts[-1].endswith(text[:end_idx]):
+                        random_merge(end_idx, start, end, text, starts, ends, texts)
+                        break
+                else:
+                    starts.append(start)
+                    ends.append(end)
+                    texts.append(text)
+
+            assert (ends[-1] + 0.001) >= starts[-1] and len(
+                texts[-1]
+            ) > 0, "{} {} {} <- {} {} {}, {} {} {}".format(
+                str(starts[-1]),
+                str(ends[-1]),
+                texts[-1],
+                caption["start"][clip_idx - 1],
+                caption["end"][clip_idx - 1],
+                caption["text"][clip_idx - 1],
+                str(start),
+                str(end),
+                text,
+            )
+
+        return {"start": starts, "end": ends, "text": texts}
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(description="dedup how2 caption")
+    parser.add_argument('--how2dir', default="data/how2")
+    args = parser.parse_args()
+
+    raw_caption_json = os.path.join(args.how2dir, "raw_caption.json")
+    raw_caption_pickle = os.path.join(args.how2dir, "raw_caption.pkl")
+    raw_caption_dedup_pickle = os.path.join(args.how2dir, "raw_caption_dedup.pkl")
+
+    def convert_to_pickle(src_fn, tgt_fn):
+        with open(src_fn) as fd:
+            captions = json.load(fd)
+
+        for video_id in captions:
+            captions[video_id] = json.dumps(captions[video_id])
+
+        with open(tgt_fn, "wb") as fw:
+            pickle.dump(captions, fw, pickle.HIGHEST_PROTOCOL)
+
+    if not os.path.isfile(raw_caption_pickle):
+        convert_to_pickle(raw_caption_json, raw_caption_pickle)
+
+    deduper = CaptionDedupProcessor(raw_caption_pickle)
+    deduper()
+    deduper.finalize(raw_caption_dedup_pickle)
+
+    """
+    # demo
+    deduper = CaptionDedupProcessor("data/how2/raw_caption.pkl")
+    deduper.single("HfIeQ9pzL5U")
+    """

+ 848 - 0
examples/MMPT/mmpt/processors/dsprocessor.py

@@ -0,0 +1,848 @@
+# Copyright (c) Facebook, Inc. All Rights Reserved
+
+"""
+Processors for all downstream (ds) tasks.
+"""
+
+import json
+import os
+import pickle
+import random
+import math
+import numpy as np
+import torch
+
+from collections import defaultdict
+
+from .processor import (
+    MetaProcessor,
+    VideoProcessor,
+    TextProcessor,
+    Aligner,
+    MMAttentionMask2DProcessor,
+)
+
+from .how2processor import TextGenerationProcessor
+
+
+# ------------- A General Aligner for all downstream tasks-----------------
+
+
+class DSAligner(Aligner):
+    """
+    Downstream (DS) aligner shared by all datasets.
+    """
+
+    def __call__(self, video_id, video_feature, text_feature, wps=0.7):
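+        # `wps` (words per second) is only used to synthesize an end timestamp
+        # for the single text clip below.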
+        # the video clip always starts at sec 0 (no random start is sampled here).
+        video_start = 0
+        video_end = min(len(video_feature), self.max_video_len)
+        # the whole sequence is a single clip.
+        video_clips = {"start": [video_start], "end": [video_end]}
+
+        text_feature = {
+            "cap": [text_feature],
+            "start": [video_start],
+            "end": [len(text_feature) / wps],
+        }
+        text_clip_indexs = [0]
+
+        vfeats, vmasks = self._build_video_seq(
+            video_feature, video_clips
+        )
+        caps, cmasks = self._build_text_seq(
+            text_feature, text_clip_indexs
+        )
+
+        return {
+            "caps": caps,
+            "cmasks": cmasks,
+            "vfeats": vfeats,
+            "vmasks": vmasks,
+            "video_id": video_id,
+        }
+
+
+class NLGTextProcessor(TextProcessor):
+    """
+    Also return the original text as ref.
+    """
+    def __call__(self, text_id):
+        return super().__call__(text_id), text_id
+
+
+class DSNLGAligner(DSAligner):
+    """extend with the capability of 2d mask for generation."""
+    def __init__(self, config):
+        super().__init__(config)
+        self.attnmasker = MMAttentionMask2DProcessor()
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(
+            self.bert_name, use_fast=self.use_fast,
+            bos_token="[CLS]", eos_token="[SEP]"
+        )
+        self.tokenizer = tokenizer
+        self.bos_token_id = tokenizer.bos_token_id
+        self.eos_token_id = tokenizer.eos_token_id
+        self.textgen = TextGenerationProcessor(tokenizer)
+
+    def __call__(self, video_id, video_feature, text_feature):
+        output = super().__call__(video_id, video_feature, text_feature[0])
+        if self.split == "test":
+            # output.update({"ref": text_feature[1]})
+            output.update({"ref": self.tokenizer.decode(
+                output["caps"], skip_special_tokens=True)})
+            text_label = output["caps"]
+            cmasks = torch.BoolTensor([1] * text_label.size(0))
+            caps = torch.LongTensor([
+                self.cls_token_id,
+                self.sep_token_id,
+                self.bos_token_id])
+        else:
+            caps, text_label = self.textgen(output["caps"])
+            cmasks = output["cmasks"]
+
+        attention_mask = self.attnmasker(
+            output["vmasks"], cmasks, "textgen")
+
+        output.update({
+            "caps": caps,
+            "cmasks": cmasks,
+            "text_label": text_label,
+            "attention_mask": attention_mask,
+        })
+        return output
+
+
+# -------------------- MSRVTT ------------------------
+
+
+class MSRVTTMetaProcessor(MetaProcessor):
+    """MSRVTT dataset.
+    reference: `howto100m/msrvtt_dataloader.py`
+    """
+
+    def __init__(self, config):
+        super().__init__(config)
+        import pandas as pd
+        data = pd.read_csv(self._get_split_path(config))
+        # TODO: add a text1ka flag.
+        if config.split == "train" \
+                and config.full_test_path is not None \
+                and config.jsfusion_path is not None:
+            # add testing videos from full_test_path not used by jsfusion.
+            additional_data = pd.read_csv(config.full_test_path)
+            jsfusion_data = pd.read_csv(config.jsfusion_path)
+
+            for video_id in additional_data["video_id"]:
+                if video_id not in jsfusion_data["video_id"].values:
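+                    # note: DataFrame.append assumes pandas < 2.0 (it was removed in pandas 2.x).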
+                    data = data.append(
+                        {"video_id": video_id}, ignore_index=True)
+
+        if config.dup is not None and config.split == "train":
+            data = data.append([data] * (config.dup - 1), ignore_index=True)
+        self.data = data
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        """slightly modify with if condition to combine train/test."""
+        vid, sentence = None, None
+        vid = self.data["video_id"].values[idx]
+        if "sentence" in self.data:  # for testing.
+            sentence = self.data["sentence"].values[idx]
+        else:  # for training.
+            sentence = vid
+        return vid, sentence
+
+
+class MSRVTTTextProcessor(TextProcessor):
+    """MSRVTT dataset.
+    reference: `msrvtt_dataloader.py` `MSRVTT_TrainDataLoader`.
+    TODO (huxu): add max_words.
+    """
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.sentences = None
+        if config.json_path is not None and config.split == "train":
+            with open(config.json_path) as fd:
+                self.data = json.load(fd)
+            self.sentences = defaultdict(list)
+            for s in self.data["sentences"]:
+                self.sentences[s["video_id"]].append(s["caption"])
+
+    def __call__(self, text_id):
+        if self.sentences is not None:
+            rind = random.randint(0, len(self.sentences[text_id]) - 1)
+            sentence = self.sentences[text_id][rind]
+        else:
+            sentence = text_id
+        caption = self.tokenizer(sentence, add_special_tokens=False)
+        return caption["input_ids"]
+
+
+class MSRVTTNLGTextProcessor(MSRVTTTextProcessor):
+    """TODO: change dsaligner and merge to avoid any NLG text processor."""
+    def __call__(self, text_id):
+        if self.sentences is not None:
+            rind = random.randint(0, len(self.sentences[text_id]) - 1)
+            sentence = self.sentences[text_id][rind]
+        else:
+            sentence = text_id
+        caption = self.tokenizer(sentence, add_special_tokens=False)
+        return caption["input_ids"], sentence
+
+
+class MSRVTTQAMetaProcessor(MetaProcessor):
+    """MSRVTT-QA: retrieval-based multi-choice QA from JSFusion dataset.
+    For simplicity, we use the train retrieval model.
+    reference: `https://github.com/yj-yu/lsmdc`
+    """
+
+    def __init__(self, config):
+        super().__init__(config)
+        import pandas as pd
+        csv_data = pd.read_csv(self._get_split_path(config), sep="\t")
+        data = []
+        for video_id, a1, a2, a3, a4, a5, answer in zip(
+                csv_data["vid_key"].values,
+                csv_data["a1"].values,
+                csv_data["a2"].values,
+                csv_data["a3"].values,
+                csv_data["a4"].values,
+                csv_data["a5"].values,
+                csv_data["answer"].values):
+            video_id = video_id.replace("msr", "video")
+            data.append((video_id, (answer, [a1, a2, a3, a4, a5])))
+        self.data = data
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        return self.data[idx]
+
+
+class MSRVTTQATextProcessor(TextProcessor):
+    """MSRVTT-QA dataset.
+    text_ans is of format `(answer, [a1, a2, a3, a4, a5])`.
+    """
+
+    def __call__(self, text_ans):
+        for ans_idx, ans in enumerate(text_ans[1]):
+            if isinstance(ans, str):
+                text_ans[1][ans_idx] = self.tokenizer(ans, add_special_tokens=False)["input_ids"]
+        return text_ans
+
+
+class MSRVTTQAAligner(DSAligner):
+    """MSRVTT dataset.
+    similar to sampling in how2;
+    the parent __call__ is invoked once per candidate answer.
+    """
+
+    def __call__(self, video_id, video_feature, text_feature, wps=0.7):
+        caps = []
+        cmasks = []
+        answer = text_feature[0]
+        for ans_idx, _text_feature in enumerate(text_feature[1]):
+            output = super().__call__(
+                video_id, video_feature, _text_feature, wps)
+            caps.append(output["caps"])
+            cmasks.append(output["cmasks"])
+        output.update({
+            "caps": torch.stack(caps),
+            "cmasks": torch.stack(cmasks),
+            "answers": torch.LongTensor([answer]),
+        })
+        return output
+
+
+# -------------------- Youcook -----------------------
+
+
+class YoucookMetaProcessor(MetaProcessor):
+    """Youcook dataset.
+    reference: `howto100m/youcook_dataloader.py`
+    note that the data can differ from the original release:
+    (1) some videos already in Howto100m are removed;
+    (2) stop words are removed from captions.
+    TODO (huxu): make a flag to load the original caption.
+    (see youcookii_annotations_trainval.json).
+
+    The max_video_len can be 264 and text can be 64 tokens.
+    In practice we may not need sequences that long; see projects/task/youcook.yaml.
+    """
+
+    def __init__(self, config):
+        super().__init__(config)
+        vfeat_dir = config.vfeat_dir
+        print(self._get_split_path(config))
+        with open(self._get_split_path(config), "rb") as fd:
+            data = pickle.load(fd)
+            all_valid_video_ids = set(
+                [os.path.splitext(fn)[0] for fn in os.listdir(vfeat_dir)]
+            )
+            recs = []
+            video_ids = set()
+            valid_video_ids = set()
+            for rec in data:  # filter videos not available.
+                udl_idx = rec["id"].rindex("_")
+                video_id = rec["id"][:udl_idx]
+                video_ids.add(video_id)
+                if video_id in all_valid_video_ids:
+                    valid_video_ids.add(video_id)
+                    recs.append(rec)
+            print("total video_ids in .pkl", len(video_ids))
+            print("valid video_ids in .pkl", len(valid_video_ids))
+            print("please verify {train,val}_list.txt")
+            data = recs
+            self.data = data
+
+        with open(config.trainval_annotation) as fd:
+            self.youcook_annotation = json.load(fd)["database"]
+        if config.use_annotation_text is True:
+            print("using text in annotation.")
+            self.use_annotation_caption = True
+        else:
+            self.use_annotation_caption = False
+
+    def __getitem__(self, idx):
+        def _get_video_and_caption(rec):
+            vid = rec["id"]
+            udl_idx = vid.rindex("_")
+            video_id, clip_id = vid[:udl_idx], int(vid[udl_idx + 1:])
+            clip = self.youcook_annotation[video_id]["annotations"][clip_id]
+            start, end = clip["segment"]
+            if self.use_annotation_caption:
+                caption = clip["sentence"]
+            else:
+                caption = rec["caption"]
+            return (video_id, start, end), caption
+
+        rec = self.data[idx]
+        video_info, text_info = _get_video_and_caption(rec)
+        return video_info, text_info
+
+
+class YoucookVideoProcessor(VideoProcessor):
+    """video_fn is a tuple of (video_id, start, end) now."""
+
+    def __call__(self, video_fn):
+        video_id, start, end = video_fn
+        feat = np.load(os.path.join(self.vfeat_dir, video_id + ".npy"))
+        return feat[start:end]
+
+
+class YoucookNLGMetaProcessor(MetaProcessor):
+    """NLG uses the original split:
+    `train_list.txt` and `val_list.txt`
+    """
+
+    def __init__(self, config):
+        super().__init__(config)
+        vfeat_dir = config.vfeat_dir
+        print(self._get_split_path(config))
+        with open(self._get_split_path(config)) as fd:
+            video_ids = [
+                line.strip().split("/")[1] for line in fd.readlines()]
+            print("total video_ids in train/val_list.txt", len(video_ids))
+
+            all_valid_video_ids = set(
+                [os.path.splitext(fn)[0] for fn in os.listdir(vfeat_dir)]
+            )
+            video_ids = [
+                video_id for video_id in video_ids
+                if video_id in all_valid_video_ids]
+
+            print("valid video_ids in train/val_list.txt", len(video_ids))
+        with open(config.trainval_annotation) as fd:
+            self.youcook_annotation = json.load(fd)["database"]
+
+        data = []
+        for video_id in video_ids:
+            for clip in self.youcook_annotation[video_id]["annotations"]:
+                start, end = clip["segment"]
+                caption = clip["sentence"]
+                data.append(((video_id, start, end), caption))
+        self.data = data
+
+    def __getitem__(self, idx):
+        return self.data[idx]
+
+
+# --------------------- CrossTask -------------------------
+
+class CrossTaskMetaProcessor(MetaProcessor):
+    def __init__(self, config):
+        super().__init__(config)
+        np.random.seed(0)  # deterministic random split.
+        task_vids = self._get_vids(
+            config.train_csv_path,
+            config.vfeat_dir,
+            config.annotation_path)
+
+        val_vids = self._get_vids(
+            config.val_csv_path,
+            config.vfeat_dir,
+            config.annotation_path)
+
+        # filter out tasks and vids that appear in val_vids.
+        task_vids = {
+            task: [
+                vid for vid in vids
+                if task not in val_vids or vid not in val_vids[task]]
+            for task, vids in task_vids.items()}
+
+        primary_info = self._read_task_info(config.primary_path)
+        test_tasks = set(primary_info['steps'].keys())
+
+        # if args.use_related:
+        related_info = self._read_task_info(config.related_path)
+        task_steps = {**primary_info['steps'], **related_info['steps']}
+        n_steps = {**primary_info['n_steps'], **related_info['n_steps']}
+        # else:
+        #     task_steps = primary_info['steps']
+        #     n_steps = primary_info['n_steps']
+        all_tasks = set(n_steps.keys())
+        # filter and keep task in primary or related.
+        task_vids = {
+            task: vids for task, vids in task_vids.items()
+            if task in all_tasks}
+        # vocab-by-step matrix (A) and vocab (M)
+        # (huxu): we do not use BoW.
+        # A, M = self._get_A(task_steps, share="words")
+
+        train_vids, test_vids = self._random_split(
+            task_vids, test_tasks, config.n_train)
+        print("train_num_videos", sum(len(vids) for vids in train_vids.values()))
+        print("test_num_videos", sum(len(vids) for vids in test_vids.values()))
+        # added by huxu to automatically determine the split.
+        split_map = {
+            "train": train_vids,
+            "valid": test_vids,
+            "test": test_vids
+        }
+        task_vids = split_map[config.split]
+
+        self.vids = []
+        for task, vids in task_vids.items():
+            self.vids.extend([(task, vid) for vid in vids])
+        self.task_steps = task_steps
+        self.n_steps = n_steps
+
+    def __getitem__(self, idx):
+        task, vid = self.vids[idx]
+        n_steps = self.n_steps[task]
+        steps = self.task_steps[task]
+        assert len(steps) == n_steps
+        return (task, vid, steps, n_steps), (task, vid, steps, n_steps)
+
+    def __len__(self):
+        return len(self.vids)
+
+    def _random_split(self, task_vids, test_tasks, n_train):
+        train_vids = {}
+        test_vids = {}
+        for task, vids in task_vids.items():
+            if task in test_tasks and len(vids) > n_train:
+                train_vids[task] = np.random.choice(
+                    vids, n_train, replace=False).tolist()
+                test_vids[task] = [
+                    vid for vid in vids if vid not in train_vids[task]]
+            else:
+                train_vids[task] = vids
+        return train_vids, test_vids
+
+    def _get_vids(self, path, vfeat_dir, annotation_path):
+        """refactored from
+        https://github.com/DmZhukov/CrossTask/blob/master/data.py
+        changes: add `vfeat_dir` to check if the video is available;
+        add `annotation_path` to check if the annotation is available.
+        """
+
+        task_vids = {}
+        with open(path, 'r') as f:
+            for line in f:
+                task, vid, url = line.strip().split(',')
+                # double check the video is available.
+                if not os.path.exists(
+                        os.path.join(vfeat_dir, vid + ".npy")):
+                    continue
+                # double check the annotation is available.
+                if not os.path.exists(os.path.join(
+                        annotation_path,
+                        task + "_" + vid + ".csv")):
+                    continue
+                if task not in task_vids:
+                    task_vids[task] = []
+                task_vids[task].append(vid)
+        return task_vids
+
+    def _read_task_info(self, path):
+        titles = {}
+        urls = {}
+        n_steps = {}
+        steps = {}
+        with open(path, 'r') as f:
+            idx = f.readline()
+            while idx != '':
+                idx = idx.strip()
+                titles[idx] = f.readline().strip()
+                urls[idx] = f.readline().strip()
+                n_steps[idx] = int(f.readline().strip())
+                steps[idx] = f.readline().strip().split(',')
+                next(f)
+                idx = f.readline()
+        return {
+            'title': titles,
+            'url': urls,
+            'n_steps': n_steps,
+            'steps': steps
+        }
+
+    def _get_A(self, task_steps, share="words"):
+        raise ValueError("running get_A is not allowed for BERT.")
+        """Step-to-component matrices."""
+        if share == 'words':
+            # share words
+            task_step_comps = {
+                task: [step.split(' ') for step in steps]
+                for task, steps in task_steps.items()}
+        elif share == 'task_words':
+            # share words within same task
+            task_step_comps = {
+                task: [[task+'_'+tok for tok in step.split(' ')] for step in steps]
+                for task, steps in task_steps.items()}
+        elif share == 'steps':
+            # share whole step descriptions
+            task_step_comps = {
+                task: [[step] for step in steps] for task, steps in task_steps.items()}
+        else:
+            # no sharing
+            task_step_comps = {
+                task: [[task+'_'+step] for step in steps]
+                for task, steps in task_steps.items()}
+        # BERT tokenizer here?
+        vocab = []
+        for task, steps in task_step_comps.items():
+            for step in steps:
+                vocab.extend(step)
+        vocab = {comp: m for m, comp in enumerate(set(vocab))}
+        M = len(vocab)
+        A = {}
+        for task, steps in task_step_comps.items():
+            K = len(steps)
+            a = torch.zeros(M, K)
+            for k, step in enumerate(steps):
+                a[[vocab[comp] for comp in step], k] = 1
+            a /= a.sum(dim=0)
+            A[task] = a
+        return A, M
+
+
+class CrossTaskVideoProcessor(VideoProcessor):
+    def __call__(self, video_fn):
+        task, vid, steps, n_steps = video_fn
+        video_fn = os.path.join(self.vfeat_dir, vid + ".npy")
+        feat = np.load(video_fn)
+        return feat
+
+
+class CrossTaskTextProcessor(TextProcessor):
+    def __call__(self, text_id):
+        task, vid, steps, n_steps = text_id
+        step_ids = []
+        for step_str in steps:
+            step_ids.append(
+                self.tokenizer(step_str, add_special_tokens=False)["input_ids"]
+            )
+        return step_ids
+
+
+class CrossTaskAligner(Aligner):
+    """
+    TODO: the formulation of this task is not yet clear; finish this later.
+    """
+    def __init__(self, config):
+        super().__init__(config)
+        self.annotation_path = config.annotation_path
+        self.sliding_window = config.sliding_window
+        self.sliding_window_size = config.sliding_window_size
+
+    def __call__(self, video_id, video_feature, text_feature):
+        task, vid, steps, n_steps = video_id
+        annot_path = os.path.join(
+            self.annotation_path, task + '_' + vid + '.csv')
+        video_len = len(video_feature)
+
+        labels = torch.from_numpy(self._read_assignment(
+            video_len, n_steps, annot_path)).float()
+
+        vfeats, vmasks, targets = [], [], []
+        # sliding window on video features and targets.
+        for window_start in range(0, video_len, self.sliding_window):
+            video_start = 0
+            video_end = min(video_len - window_start, self.sliding_window_size)
+            video_clip = {"start": [video_start], "end": [video_end]}
+
+            vfeat, vmask = self._build_video_seq(
+                video_feature[window_start: window_start + video_end],
+                video_clip
+            )
+
+            target = labels[window_start: window_start + video_end]
+            assert len(vfeat) >= len(target), "{},{}".format(len(vfeat), len(target))
+            # TODO: randomly drop all zero targets for training ?
+            # if self.split == "train" and target.sum() == 0:
+            #     continue
+            vfeats.append(vfeat)
+            vmasks.append(vmask)
+            targets.append(target)
+
+            if (video_len - window_start) <= self.sliding_window_size:
+                break
+
+        vfeats = torch.stack(vfeats)
+        vmasks = torch.stack(vmasks)
+        targets = torch.cat(targets, dim=0)
+
+        caps, cmasks = [], []
+        for step in text_feature:
+            step_text_feature = {"start": [0], "end": [1], "cap": [step]}
+            step_text_clip_index = [0]
+            cap, cmask = self._build_text_seq(
+                step_text_feature, step_text_clip_index
+            )
+            caps.append(cap)
+            cmasks.append(cmask)
+        caps = torch.stack(caps)
+        cmasks = torch.stack(cmasks)
+
+        return {
+            "caps": caps,
+            "cmasks": cmasks,
+            "vfeats": vfeats,  # X for original code.
+            "vmasks": vmasks,
+            "targets": targets,
+            "video_id": vid,
+            "task": task,
+            "video_len": video_len  # for later checking.
+        }
+
+    def _read_assignment(self, T, K, path):
+        """
+        refactored from https://github.com/DmZhukov/CrossTask/blob/master/data.py
+        How to interpret the constraints on the loss being minimized:
+        lambd is a big number;
+        self.lambd * C is a big number for all valid positions (the csv stores invalid ones).
+
+        def forward(self, O, Y, C):
+            return (Y*(self.lambd * C - self.lsm(O))).mean(dim=0).sum()
+
+        This loads the csv file and fills in the step column from the start row to the end row.
+        """
+
+        Y = np.zeros([T, K], dtype=np.uint8)
+        with open(path, 'r') as f:
+            for line in f:
+                step, start, end = line.strip().split(',')
+                start = int(math.floor(float(start)))
+                end = int(math.ceil(float(end)))
+                step = int(step) - 1
+                Y[start:end, step] = 1
+        return Y
+
+
+# --------------------- COIN -------------------------
+
+class MetaTextBinarizer(Aligner):
+    def __call__(self, text_feature):
+        text_feature = {
+            "cap": [text_feature],
+            "start": [0.],
+            "end": [100.],
+        }
+        text_clip_indexs = [0]
+
+        caps, cmasks = self._build_text_seq(
+            text_feature, text_clip_indexs
+        )
+        return {"caps": caps, "cmasks": cmasks}
+
+
+class COINActionSegmentationMetaProcessor(MetaProcessor):
+    split_map = {
+        "train": "training",
+        "valid": "testing",
+        "test": "testing",
+    }
+
+    def __init__(self, config):
+        super().__init__(config)
+        with open(self._get_split_path(config)) as fr:
+            database = json.load(fr)["database"]
+        id2label = {}
+        data = []
+        # filter the data by split.
+        for video_id, rec in database.items():
+            # always use testing to determine label_set
+            if rec["subset"] == "testing":
+                for segment in rec["annotation"]:
+                    id2label[int(segment["id"])] = segment["label"]
+        # text_labels is used for ZS setting
+        self.text_labels = ["none"] * len(id2label)
+        for label_id in id2label:
+            self.text_labels[label_id-1] = id2label[label_id]
+
+        id2label[0] = "O"
+        print("num of labels", len(id2label))
+
+        for video_id, rec in database.items():
+            if not os.path.isfile(os.path.join(config.vfeat_dir, video_id + ".npy")):
+                continue
+            if rec["subset"] == COINActionSegmentationMetaProcessor.split_map[self.split]:
+                starts, ends, labels = [], [], []
+                for segment in rec["annotation"]:
+                    start, end = segment["segment"]
+                    label = int(segment["id"])
+                    starts.append(start)
+                    ends.append(end)
+                    labels.append(label)
+                data.append(
+                    (video_id, {"start": starts, "end": ends, "label": labels}))
+        self.data = data
+
+    def meta_text_labels(self, config):
+        from transformers import default_data_collator
+        from ..utils import get_local_rank
+
+        text_processor = TextProcessor(config)
+        binarizer = MetaTextBinarizer(config)
+        # TODO: add prompts to .yaml.
+        text_labels = [label for label in self.text_labels]
+
+        if get_local_rank() == 0:
+            print(text_labels)
+
+        outputs = []
+        for text_label in text_labels:
+            text_feature = text_processor(text_label)
+            outputs.append(binarizer(text_feature))
+        return default_data_collator(outputs)
+
+    def __getitem__(self, idx):
+        return self.data[idx]
+
+
+class COINActionSegmentationTextProcessor(TextProcessor):
+    def __call__(self, text_label):
+        return text_label
+
+
+class COINActionSegmentationAligner(Aligner):
+    def __init__(self, config):
+        super().__init__(config)
+        self.sliding_window = config.sliding_window
+        self.sliding_window_size = config.sliding_window_size
+
+    def __call__(self, video_id, video_feature, text_feature):
+        starts, ends, label_ids = text_feature["start"], text_feature["end"], text_feature["label"]
+        # sliding window.
+        video_len = len(video_feature)
+
+        vfeats, vmasks, targets = [], [], []
+        # sliding window on video features and targets.
+        for window_start in range(0, video_len, self.sliding_window):
+            video_start = 0
+            video_end = min(video_len - window_start, self.sliding_window_size)
+            video_clip = {"start": [video_start], "end": [video_end]}
+            vfeat, vmask = self._build_video_seq(
+                video_feature[window_start: window_start + video_end],
+                video_clip
+            )
+            # covers video length only.
+            target = torch.full_like(vmask, -100, dtype=torch.long)
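+            # -100 matches the default ignore_index of torch.nn.CrossEntropyLoss.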
+            target[vmask] = 0
+            for start, end, label_id in zip(starts, ends, label_ids):
+                if (window_start < end) and (start < (window_start + video_end)):
+                    start_offset = max(0, math.floor(start) - window_start)
+                    end_offset = min(video_end, math.ceil(end) - window_start)
+                    target[start_offset:end_offset] = label_id
+            vfeats.append(vfeat)
+            vmasks.append(vmask)
+            targets.append(target)
+            if (video_len - window_start) <= self.sliding_window_size:
+                break
+
+        vfeats = torch.stack(vfeats)
+        vmasks = torch.stack(vmasks)
+        targets = torch.stack(targets)
+        video_targets = torch.full((video_len,), 0)
+        for start, end, label_id in zip(starts, ends, label_ids):
+            start_offset = max(0, math.floor(start))
+            end_offset = min(video_len, math.ceil(end))
+            video_targets[start_offset:end_offset] = label_id
+
+        caps = torch.LongTensor(
+            [[self.cls_token_id, self.sep_token_id,
+              self.pad_token_id, self.sep_token_id]],
+            ).repeat(vfeats.size(0), 1)
+        cmasks = torch.BoolTensor(
+            [[0, 1, 0, 1]]  # pad are valid for attention.
+            ).repeat(vfeats.size(0), 1)
+        return {
+            "caps": caps,
+            "cmasks": cmasks,
+            "vfeats": vfeats,  # X for original code.
+            "vmasks": vmasks,
+            "targets": targets,
+            "video_id": video_id,
+            "video_len": video_len,  # for later checking.
+            "video_targets": video_targets
+        }
+
+
+class DiDeMoMetaProcessor(MetaProcessor):
+    """reference: https://github.com/LisaAnne/LocalizingMoments/blob/master/utils/eval.py
+    https://github.com/LisaAnne/LocalizingMoments/blob/master/utils/data_processing.py
+    """
+    def __init__(self, config):
+        super().__init__(config)
+
+        assert "test" in self._get_split_path(config), "DiDeMo only supports zero-shot testing for now."
+
+        with open(self._get_split_path(config)) as data_file:
+            json_data = json.load(data_file)
+
+        data = []
+        for record in json_data:
+            data.append((record["video"], record["description"]))
+        self.data = data
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        return self.data[idx]
+
+
+class DiDeMoTextProcessor(TextProcessor):
+    """reference: https://github.com/LisaAnne/LocalizingMoments/blob/master/utils/eval.py
+    https://github.com/LisaAnne/LocalizingMoments/blob/master/utils/data_processing.py
+    """
+
+    def __call__(self, text):
+        return self.tokenizer(text, add_special_tokens=False)["input_ids"]
+
+
+class DiDeMoAligner(DSAligner):
+    """
+    check video length.
+    """
+
+    def __call__(self, video_id, video_feature, text_feature):
+        # print(video_feature.shape[0])
+        return super().__call__(video_id, video_feature, text_feature)

+ 887 - 0
examples/MMPT/mmpt/processors/how2processor.py

@@ -0,0 +1,887 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Copyright (c) Facebook, Inc. All Rights Reserved
+
+
+import torch
+import math
+import pickle
+import random
+import os
+import numpy as np
+
+from collections import deque
+from typing import Optional, Tuple, List
+from .processor import (
+    Processor,
+    MetaProcessor,
+    TextProcessor,
+    Aligner,
+    MMAttentionMask2DProcessor
+)
+
+from ..utils import ShardedTensor
+
+
+class How2MetaProcessor(MetaProcessor):
+    def __init__(self, config):
+        super().__init__(config)
+        path = self._get_split_path(config)
+        with open(path) as fd:
+            self.data = [line.strip() for line in fd]
+
+    def __getitem__(self, idx):
+        video_id = self.data[idx]
+        return video_id, video_id
+
+
+class ShardedHow2MetaProcessor(How2MetaProcessor):
+    def __init__(self, config):
+        super().__init__(config)
+        self.split = str(config.split)
+        self.vfeat_dir = config.vfeat_dir
+        self._init_shard()
+
+    def _init_shard(self):
+        if self.split == "train":
+            meta_fn = os.path.join(self.vfeat_dir, "train" + "_meta.pkl")
+            with open(meta_fn, "rb") as fr:
+                meta = pickle.load(fr)
+        elif self.split == "valid":
+            meta_fn = os.path.join(self.vfeat_dir, "val" + "_meta.pkl")
+            with open(meta_fn, "rb") as fr:
+                meta = pickle.load(fr)
+        elif self.split == "test":
+            print("use how2 val as test.")
+            meta_fn = os.path.join(self.vfeat_dir, "val" + "_meta.pkl")
+            with open(meta_fn, "rb") as fr:
+                meta = pickle.load(fr)
+        else:
+            raise ValueError("unsupported for MetaProcessor:", self.split)
+        video_id_to_shard = {}
+        for shard_id in meta:
+            for video_idx, video_id in enumerate(meta[shard_id]):
+                video_id_to_shard[video_id] = (shard_id, video_idx)
+        self.video_id_to_shard = video_id_to_shard
+
+    def __getitem__(self, idx):
+        video_id, video_id = super().__getitem__(idx)
+        shard_id, shard_idx = self.video_id_to_shard[video_id]
+        meta = (video_id, idx, shard_id, shard_idx)
+        return meta, meta
+
+
+class ShardedVideoProcessor(Processor):
+    """
+    mmaped shards of numpy video features.
+    """
+
+    def __init__(self, config):
+        self.split = str(config.split)
+        self.vfeat_dir = config.vfeat_dir
+
+    def __call__(self, video_id):
+        _, _, shard_id, video_idx = video_id
+        if self.split == "train":
+            shard = ShardedTensor.load(
+                os.path.join(self.vfeat_dir, "train" + "_" + str(shard_id)),
+                "r"
+            )
+        elif self.split == "valid":
+            shard = ShardedTensor.load(
+                os.path.join(self.vfeat_dir, "val" + "_" + str(shard_id)),
+                "r"
+            )
+        elif self.split == "test":
+            shard = ShardedTensor.load(
+                os.path.join(self.vfeat_dir, "val" + "_" + str(shard_id)),
+                "r"
+            )
+        else:
+            raise ValueError("unknown split", self.split)
+        feat = shard[video_idx]
+        return feat
+
+
+class ShardedTextProcessor(Processor):
+    def __init__(self, config):
+        self.tfeat_dir = str(config.tfeat_dir)
+        self.split = str(config.split)
+
+    def __call__(self, video_id):
+        _, _, shard_id, shard_idx = video_id
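+        # `tfeat_dir` is used as a raw path prefix (no os.path.join / separator added).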
+        if self.split == "train":
+            target_path = self.tfeat_dir + "train" + "_" + str(shard_id)
+        elif self.split == "valid":
+            target_path = self.tfeat_dir + "val" + "_" + str(shard_id)
+        elif self.split == "test":
+            target_path = self.tfeat_dir + "val" + "_" + str(shard_id)
+        else:
+            raise ValueError("unknown split", self.split)
+
+        startend = ShardedTensor.load(
+            target_path + ".startends", "r")[shard_idx]
+        cap_ids = ShardedTensor.load(
+            target_path + ".caps_ids", "r")[shard_idx]
+        cap = []
+        for clip_idx in range(len(cap_ids)):
+            clip = cap_ids[clip_idx]
+            cap.append(clip[clip != -1].tolist())
+        start, end = startend[:, 0].tolist(), startend[:, 1].tolist()
+        return {"start": start, "end": end, "cap": cap}
+
+
+class FixedLenAligner(Aligner):
+    """
+    In the model we assume text is on the left (closer to BERT formulation)
+    and video is on the right.
+    We fix the total length of text + video.
+    max_video_len is in number of seconds.
+    max_text_len is in number of tokens.
+
+    special token format:
+    we use the format [CLS] [SEP] text tokens [SEP] [PAD] ...
+    [CLS] will be split out into:
+    [CLS] video tokens [SEP] text tokens [SEP] [PAD] ...
+    token_type_ids will be generated by the model (for now).
+    0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
+    | first sequence    | second sequence |
+    so each sequence owns a [SEP] token for no-ops.
+    """
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.text_clip_sampler = TextClipSamplingProcessor(
+            self.max_len - self.max_video_len - 3
+        )
+        """
+        decide subsampling:
+        `config.subsampling` will change batch_size in trainer.
+        `config.clip_per_video` (used by RetriTask) doesn't
+            change batch_size in trainer.
+        """
+        subsampling = config.subsampling \
+            if config.subsampling is not None else None
+        if config.clip_per_video is not None:
+            subsampling = config.clip_per_video
+        self.subsampling = subsampling
+
+    def _get_text_maxlen(self):
+        # use max text len
+        return self.text_clip_sampler.max_text_len
+
+    def __call__(self, video_id, video_feature, text_feature):
+        from transformers import default_data_collator
+        video_idx = video_id[1]
+        if self.subsampling is not None and self.subsampling >= 1:
+            batch = []
+            for _ in range(self.subsampling):
+                centerclip_idx = random.randint(
+                                    0, len(text_feature["start"]) - 1)
+                batch.append(
+                    self.sampling(
+                        video_idx,
+                        video_feature,
+                        text_feature,
+                        centerclip_idx,
+                        self._get_text_maxlen()
+                    ))
+            batch = self.batch_post_processing(batch, video_feature)
+            batch = default_data_collator(batch)
+        else:
+            raise ValueError(
+                "dataset.subsampling must be >= 1 for efficient video loading.")
+
+        batch["video_id"] = video_id if isinstance(video_id, str) \
+            else video_id[0]
+        # e2e: make sure the frame features are in a tensor.
+        assert torch.is_tensor(batch["vfeats"])
+        return batch
+
+    def sampling(
+        self,
+        video_idx,
+        video_feature,
+        text_feature,
+        centerclip_idx=None,
+        sampled_max_text_len=None,
+    ):
+        text_clip_indexs = self.text_clip_sampler(
+            text_feature, centerclip_idx,
+            sampled_max_text_len
+        )
+        if isinstance(video_feature, np.ndarray):
+            video_len = len(video_feature)
+        else:
+            video_len = math.ceil(text_feature["end"][-1])
+
+        video_end = min(
+            math.ceil(text_feature["end"][text_clip_indexs[-1]]),
+            video_len
+        )
+        video_start = max(
+            min(
+                math.floor(text_feature["start"][text_clip_indexs[0]]),
+                video_end),
+            0
+        )
+
+        video_clips = {"start": [video_start], "end": [video_end]}
+
+        # tensorize.
+        vfeats, vmasks = self._build_video_seq(
+            video_feature, video_clips
+        )
+        caps, cmasks = self._build_text_seq(
+            text_feature, text_clip_indexs
+        )
+
+        text_start = text_clip_indexs[0]
+        text_end = text_clip_indexs[-1] + 1
+
+        return {
+            "caps": caps,
+            "cmasks": cmasks,
+            "vfeats": vfeats,
+            "vmasks": vmasks,
+            "video_start": video_start,
+            "video_end": video_end,
+            "text_start": text_start,
+            "text_end": text_end,
+        }
+
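A small worked example of the budgeting above, with hypothetical lengths: the text budget handed to `TextClipSamplingProcessor` is `max_len - max_video_len - 3` (one [CLS] and two [SEP] tokens), and the video window is clipped to the time span of the selected text clips.

import math

max_len, max_video_len = 128, 32
max_text_len = max_len - max_video_len - 3                # 93 text tokens

# toy caption timestamps (seconds) and the clips picked by the sampler
starts, ends = [3.0, 8.5, 12.0], [8.0, 11.5, 15.0]
text_clip_indexs = [0, 1]
video_len = 16                                            # available video tokens

video_end = min(math.ceil(ends[text_clip_indexs[-1]]), video_len)               # 12
video_start = max(min(math.floor(starts[text_clip_indexs[0]]), video_end), 0)   # 3
print(max_text_len, video_start, video_end)               # 93 3 12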
+
+class VariedLenAligner(FixedLenAligner):
+    def __init__(self, config):
+        super().__init__(config)
+        self.sampled_min_len = config.sampled_min_len
+        self.sampled_max_len = config.sampled_max_len
+
+    def _get_text_maxlen(self):
+        return random.randint(self.sampled_min_len, self.sampled_max_len)
+
+
+class StartClipAligner(VariedLenAligner):
+    def sampling(
+        self,
+        video_idx,
+        video_feature,
+        text_feature,
+        centerclip_idx=None,
+        sampled_max_text_len=None,
+    ):
+        return super().sampling(
+            video_idx, video_feature, text_feature, 0)
+
+
+class OverlappedAligner(VariedLenAligner):
+    """video clip and text clip has overlappings
+    but may not be the same start/end."""
+    def __init__(self, config):
+        super().__init__(config)
+        self.sampled_video_min_len = config.sampled_video_min_len
+        self.sampled_video_max_len = config.sampled_video_max_len
+
+        self.video_clip_sampler = VideoClipSamplingProcessor()
+
+    def _get_video_maxlen(self):
+        return random.randint(
+            self.sampled_video_min_len, self.sampled_video_max_len)
+
+    def sampling(
+        self,
+        video_idx,
+        video_feature,
+        text_feature,
+        centerclip_idx=None,
+        sampled_max_text_len=None,
+    ):
+        text_clip_indexs = self.text_clip_sampler(
+            text_feature, centerclip_idx,
+            sampled_max_text_len
+        )
+        if isinstance(video_feature, np.ndarray):
+            video_len = len(video_feature)
+        else:
+            video_len = math.ceil(text_feature["end"][-1])
+        low = math.floor(text_feature["start"][text_clip_indexs[0]])
+        high = math.ceil(text_feature["end"][text_clip_indexs[-1]])
+        if low < high:
+            center = random.randint(low, high)
+        else:
+            center = int((low + high) // 2)
+        center = max(0, min(video_feature.shape[0] - 1, center))
+
+        assert 0 <= center < video_feature.shape[0]
+
+        video_clips = self.video_clip_sampler(
+            video_len, self._get_video_maxlen(), center
+        )
+        video_start = video_clips["start"][0]
+        video_end = video_clips["end"][0]
+
+        # tensorize.
+        vfeats, vmasks = self._build_video_seq(
+            video_feature, video_clips
+        )
+        caps, cmasks = self._build_text_seq(
+            text_feature, text_clip_indexs
+        )
+
+        text_start = text_clip_indexs[0]
+        text_end = text_clip_indexs[-1] + 1
+
+        return {
+            "caps": caps,
+            "cmasks": cmasks,
+            "vfeats": vfeats,
+            "vmasks": vmasks,
+            "video_start": video_start,
+            "video_end": video_end,
+            "text_start": text_start,
+            "text_end": text_end,
+        }
+
+
+class MFMMLMAligner(FixedLenAligner):
+    """
+    `FixedLenAligner` with Masked Language Model and Masked Frame Model.
+    """
+
+    def __init__(self, config):
+        super().__init__(config)
+        keep_prob = config.keep_prob if config.keep_prob is not None else 1.0
+        self.text_clip_sampler = TextClipSamplingProcessor(
+            self.max_len - self.max_video_len - 3, keep_prob
+        )
+        self.sampled_min_len = config.sampled_min_len
+        self.sampled_max_len = config.sampled_max_len
+        self.masked_token_sampler = TextMaskingProcessor(config)
+        self.mm_type = config.mm_type \
+            if config.mm_type is not None else "full"
+        self.attnmasker = MMAttentionMask2DProcessor() \
+            if self.mm_type == "textgen" else None
+        self.masked_frame_sampler = FrameMaskingProcessor(config)
+        self.lazy_vfeat_mask = (
+            False if config.lazy_vfeat_mask is None else config.lazy_vfeat_mask
+        )
+        self.mm_prob = config.mm_prob if config.mm_prob is not None else 0.
+
+    def __call__(self, video_id, video_feature, text_feature):
+        from transformers import default_data_collator
+        if self.subsampling is not None and self.subsampling > 1:
+            batch = []
+            for _ in range(self.subsampling):
+                centerclip_idx = random.randint(
+                                    0, len(text_feature["start"]) - 1)
+                sampled_max_text_len = random.randint(
+                    self.sampled_min_len, self.sampled_max_len
+                )
+                batch.append(
+                    self.sampling(
+                        video_id,
+                        video_feature,
+                        text_feature,
+                        centerclip_idx,
+                        sampled_max_text_len,
+                    )
+                )
+            batch = self.batch_post_processing(batch, video_feature)
+            batch = default_data_collator(batch)
+        else:
+            batch = self.sampling(video_id, video_feature, text_feature)
+            batch = self.batch_post_processing(batch, video_feature)
+        batch["video_id"] = video_id if isinstance(video_id, str) \
+            else video_id[0]
+        return batch
+
+    def sampling(
+        self,
+        video_id,
+        video_feature,
+        text_feature,
+        centerclip_idx=None,
+        sampled_max_text_len=None,
+    ):
+        output = FixedLenAligner.sampling(self,
+            video_id, video_feature, text_feature,
+            centerclip_idx, sampled_max_text_len)
+
+        masking_text, masking_video = None, None
+        if random.random() < self.mm_prob:
+            if random.random() > 0.5:
+                masking_text, masking_video = self.mm_type, "no"
+            else:
+                masking_text, masking_video = "no", "full"
+        video_feats = output["vfeats"] if not self.lazy_vfeat_mask else None
+        video_label = self.masked_frame_sampler(
+            output["vmasks"], masking_video, vfeats=video_feats)
+        caps, text_label = self.masked_token_sampler(
+            output["caps"], masking_text)
+
+        output.update({
+            "caps": caps,
+            "video_label": video_label,
+            "text_label": text_label,
+        })
+
+        if self.attnmasker is not None:
+            attention_mask = self.attnmasker(
+                output["vmasks"], output["cmasks"], masking_text)
+            output.update({
+                "attention_mask": attention_mask
+            })
+        return output
+
+
+class FrameMaskingProcessor(Processor):
+    def __init__(self, config):
+        self.mfm_probability = 0.15
+        if config.mfm_probability is not None:
+            self.mfm_probability = config.mfm_probability
+
+    def __call__(self, vmasks, modality_masking=None, vfeats=None):
+        """
+        We perform lazy masking to save data transfer time.
+        It only generates video_labels by default and the MFM model
+        will do the actual masking.
+        Return: `video_label` is a binary mask.
+        """
+        video_label = vmasks.clone()
+        if modality_masking is not None:
+            if modality_masking == "full":
+                probability_matrix = torch.full(video_label.shape, 1.)
+            elif modality_masking == "no":
+                probability_matrix = torch.full(video_label.shape, 0.)
+            elif modality_masking == "inverse":
+                probability_matrix = torch.full(
+                    video_label.shape, 1. - self.mfm_probability)
+            else:
+                raise ValueError("unknown modality masking.", modality_masking)
+        else:
+            probability_matrix = torch.full(
+                video_label.shape, self.mfm_probability)
+        masked_indices = torch.bernoulli(probability_matrix).bool()
+        # We only compute loss on masked tokens
+        video_label[~masked_indices] = 0
+        if vfeats is not None:
+            vfeats[video_label, :] = 0.0
+        return video_label
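A self-contained toy run of the lazy MFM labelling logic above (probabilities and tensors are toy values): given a boolean frame mask, roughly 15% of the valid frames are flagged for the masked-frame loss, and features are zeroed here only when `vfeats` is passed eagerly.

import torch

torch.manual_seed(0)
vmasks = torch.tensor([True] * 6 + [False] * 2)           # 6 real frames, 2 padding
video_label = vmasks.clone()
masked_indices = torch.bernoulli(torch.full(video_label.shape, 0.15)).bool()
video_label[~masked_indices] = 0                          # loss only on masked frames
print(video_label)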
+
+
+class TextGenerationProcessor(Processor):
+    def __init__(self, tokenizer):
+        self.bos_token_id = tokenizer.bos_token_id
+        self.pad_token_id = tokenizer.pad_token_id
+
+    def __call__(self, inputs):
+        labels = inputs.clone()
+        # [CLS] [SEP] for video
+        labels[:2] = -100
+        # keep [SEP] for text.
+        pad_mask = labels == self.pad_token_id
+        labels[pad_mask] = -100
+        inputs[2:] = torch.cat([
+            torch.LongTensor([self.bos_token_id]),
+            inputs[2:-1]])
+        inputs[pad_mask] = self.pad_token_id
+        assert len(inputs) == len(labels)
+        return inputs, labels
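A self-contained trace of the shift performed by TextGenerationProcessor above, with hypothetical BERT ids ([CLS]=101, [SEP]=102, [PAD]=0, and [CLS] reused as bos): the inputs are shifted right by one behind the video prefix, while the labels keep the original tokens so position i predicts the token that originally sat there.

import torch

bos_id, pad_id = 101, 0
inputs = torch.LongTensor([101, 102, 7, 8, 9, 102, 0, 0])  # [CLS] [SEP] t1 t2 t3 [SEP] [PAD] [PAD]
labels = inputs.clone()
labels[:2] = -100
pad_mask = labels == pad_id
labels[pad_mask] = -100
inputs[2:] = torch.cat([torch.LongTensor([bos_id]), inputs[2:-1]])
inputs[pad_mask] = pad_id
print(inputs.tolist())  # [101, 102, 101, 7, 8, 9, 0, 0]
print(labels.tolist())  # [-100, -100, 7, 8, 9, 102, -100, -100]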
+
+
+class TextMaskingProcessor(Processor):
+    def __init__(self, config):
+        """this function is borrowed from
+        `transformers/data/data_collator.DataCollatorForLanguageModeling`"""
+        self.mlm_probability = 0.15
+        if config.mlm_probability is not None:
+            self.mlm_probability = config.mlm_probability
+        self.bert_name = config.bert_name
+        # [CLS] is used as bos_token and [SEP] is used as eos_token.
+        # https://huggingface.co/transformers/master/model_doc/bertgeneration.html
+        from transformers import AutoTokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            self.bert_name, bos_token="[CLS]", eos_token="[SEP]")
+        self.textgen = TextGenerationProcessor(self.tokenizer)
+
+    def __call__(
+        self, inputs: torch.Tensor,
+        modality_masking=None,
+        special_tokens_mask: Optional[torch.Tensor] = None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        expand modality_masking into
+            None: traditional bert masking.
+            "no": no masking.
+            "full": all [MASK] token for generation.
+            "gen": autoregressive generation.
+        """
+        """
+        Prepare masked tokens inputs/labels for masked language modeling:
+        80% MASK, 10% random, 10% original.
+        """
+        labels = inputs.clone()
+        # We sample a few tokens in each sequence for MLM training
+        # (with probability `self.mlm_probability`)
+        if modality_masking is not None:
+            if modality_masking == "full":
+                probability_matrix = torch.full(labels.shape, 1.)
+            elif modality_masking == "no":
+                probability_matrix = torch.full(labels.shape, 0.)
+            elif modality_masking.startswith("textgen"):
+                # [CLS] [SEP] <s> ...
+                inputs, labels = self.textgen(inputs)
+                if "mask" not in modality_masking:
+                    return inputs, labels
+                inputs = self.mask_input(inputs, special_tokens_mask)
+                return inputs, labels
+            elif modality_masking == "mask":
+                inputs = self.mask_input(inputs, special_tokens_mask)
+                labels = torch.full(inputs.shape, -100)
+                return inputs, labels
+            elif modality_masking == "inverse":
+                probability_matrix = torch.full(labels.shape, 1. - self.mlm_probability)
+            else:
+                raise ValueError("unknown modality masking.", modality_masking)
+        else:
+            probability_matrix = torch.full(labels.shape, self.mlm_probability)
+
+        if special_tokens_mask is None:
+            special_tokens_mask = self.get_special_tokens_mask(
+                labels.tolist(), already_has_special_tokens=True
+            )
+            special_tokens_mask = torch.tensor(
+                special_tokens_mask, dtype=torch.bool)
+        else:
+            special_tokens_mask = special_tokens_mask.bool()
+
+        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
+        masked_indices = torch.bernoulli(probability_matrix).bool()
+        labels[~masked_indices] = -100  # We only compute loss on masked tokens
+
+        # 80% of the time,
+        # we replace masked input tokens with tokenizer.mask_token ([MASK])
+        indices_replaced = (
+            torch.bernoulli(
+                torch.full(labels.shape, 0.8)).bool() & masked_indices
+        )
+        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(
+            self.tokenizer.mask_token
+        )
+
+        # 10% of the time, we replace masked input tokens with random word
+        indices_random = (
+            torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
+            & masked_indices
+            & ~indices_replaced
+        )
+        random_words = torch.randint(
+            len(self.tokenizer), labels.shape, dtype=torch.long
+        )
+        inputs[indices_random] = random_words[indices_random]
+
+        # The rest of the time (10% of the time) we keep the masked input
+        # tokens unchanged
+        return inputs, labels
+
+    def mask_input(self, inputs, special_tokens_mask=None):
+        # the following is new with masked autoregressive.
+        probability_matrix = torch.full(
+            inputs.shape, self.mlm_probability)
+        if special_tokens_mask is None:
+            special_tokens_mask = self.get_special_tokens_mask(
+                inputs.tolist(), already_has_special_tokens=True
+            )
+            special_tokens_mask = torch.tensor(
+                special_tokens_mask, dtype=torch.bool)
+        else:
+            special_tokens_mask = special_tokens_mask.bool()
+        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
+        masked_indices = torch.bernoulli(probability_matrix).bool()
+        indices_replaced = (
+            torch.bernoulli(
+                torch.full(inputs.shape, 0.8)).bool() & masked_indices
+        )
+        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(
+            self.tokenizer.mask_token
+        )
+
+        # 10% of the time, we replace masked input tokens with random word
+        indices_random = (
+            torch.bernoulli(torch.full(inputs.shape, 0.5)).bool()
+            & masked_indices
+            & ~indices_replaced
+        )
+        random_words = torch.randint(
+            len(self.tokenizer), inputs.shape, dtype=torch.long
+        )
+        inputs[indices_random] = random_words[indices_random]
+        return inputs
+
+    def get_special_tokens_mask(
+        self, token_ids_0: List[int],
+        token_ids_1: Optional[List[int]] = None,
+        already_has_special_tokens: bool = False
+    ) -> List[int]:
+        """
+        Note: the version from transformers does not consider pad
+        tokens as special tokens.
+        """
+
+        if already_has_special_tokens:
+            if token_ids_1 is not None:
+                raise ValueError(
+                    "You should not supply a second sequence if"
+                    "the provided sequence of "
+                    "ids is already formated with special tokens "
+                    "for the model."
+                )
+            return list(map(lambda x: 1 if x in [
+                self.tokenizer.sep_token_id,
+                self.tokenizer.cls_token_id,
+                self.tokenizer.pad_token_id] else 0, token_ids_0))
+
+        if token_ids_1 is not None:
+            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
+        return [1] + ([0] * len(token_ids_0)) + [1]
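A compact, self-contained sketch of the 80/10/10 scheme implemented above, using toy ids instead of a real tokenizer (the [MASK] id 103 and the vocab size are hypothetical):

import torch

torch.manual_seed(0)
vocab_size, mask_id = 30522, 103
inputs = torch.LongTensor([101, 102, 7, 8, 9, 102])
labels = inputs.clone()
special = torch.tensor([1, 1, 0, 0, 0, 1], dtype=torch.bool)

prob = torch.full(labels.shape, 0.15)
prob.masked_fill_(special, 0.0)
masked = torch.bernoulli(prob).bool()
labels[~masked] = -100                                    # loss only on masked positions

replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked
inputs[replaced] = mask_id                                # 80%: [MASK]
randomized = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked & ~replaced
inputs[randomized] = torch.randint(vocab_size, labels.shape)[randomized]  # 10%: random id
# remaining 10% of masked positions: left unchanged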
+
+
+class TextClipSamplingProcessor(Processor):
+    def __init__(self, max_text_len, keep_prob=1.0):
+        self.max_text_len = max_text_len
+        self.max_video_len = 256  # always holds.
+        self.keep_prob = keep_prob
+
+    def __call__(
+        self,
+        text_feature,
+        centerclip_idx=None,
+        sampled_max_text_len=None,
+        sampled_max_video_len=None,
+    ):
+        # Let's use all caps for now and see if 256 can cover all of them.
+        if sampled_max_text_len is not None:
+            max_text_len = sampled_max_text_len
+        else:
+            max_text_len = self.max_text_len
+        if sampled_max_video_len is not None:
+            max_video_len = sampled_max_video_len
+        else:
+            max_video_len = self.max_video_len
+
+        t_num_clips = len(text_feature["start"])
+
+        if centerclip_idx is None:
+            centerclip_idx = random.randint(0, t_num_clips - 1)
+
+        start_idx, end_idx = centerclip_idx, centerclip_idx + 1
+        text_clip_indexs = deque()
+        text_clip_indexs.append(start_idx)
+        text_len = len(text_feature["cap"][start_idx])
+
+        video_len = max(
+            0,
+            text_feature["end"][start_idx]
+            - text_feature["start"][start_idx],
+        )
+
+        while (
+            (start_idx > 0 or end_idx < t_num_clips)
+            and text_len < max_text_len
+            and video_len < max_video_len
+        ):
+            if random.random() > 0.5 and end_idx < t_num_clips:
+                # skip the next one?
+                if random.random() > self.keep_prob and (end_idx + 1) < t_num_clips:
+                    end_idx = end_idx + 1
+                text_clip_indexs.append(end_idx)
+                text_len += len(text_feature["cap"][end_idx])
+                end_idx += 1
+            elif start_idx > 0:
+                if random.random() > self.keep_prob and (start_idx - 1) > 0:
+                    start_idx = start_idx - 1
+                start_idx -= 1
+                text_clip_indexs.insert(0, start_idx)
+                text_len += len(text_feature["cap"][start_idx])
+            else:
+                if end_idx < t_num_clips:
+                    if random.random() > self.keep_prob and (end_idx + 1) < t_num_clips:
+                        end_idx = end_idx + 1
+                    text_clip_indexs.append(end_idx)
+                    text_len += len(text_feature["cap"][end_idx])
+                    end_idx += 1
+                else:
+                    return text_clip_indexs
+            video_len = max(
+                0,
+                text_feature["end"][text_clip_indexs[-1]]
+                - text_feature["start"][text_clip_indexs[0]],
+            )
+        return text_clip_indexs
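A usage sketch of the sampler above (output varies with the random growth direction); the feature dict mirrors what `ShardedTextProcessor` returns, and examples/MMPT is assumed to be on PYTHONPATH.

from mmpt.processors.how2processor import TextClipSamplingProcessor

sampler = TextClipSamplingProcessor(max_text_len=8)
text_feature = {
    "start": [0.0, 5.0, 10.0, 15.0],
    "end":   [5.0, 10.0, 15.0, 20.0],
    "cap":   [[1, 2, 3], [4, 5], [6, 7, 8, 9], [10]],
}
clip_idxs = sampler(text_feature, centerclip_idx=1)
# a deque of contiguous clip indices containing 1, e.g. deque([0, 1, 2])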
+
+
+class VideoClipSamplingProcessor(Processor):
+    def __call__(self, video_len, max_video_len, center):
+        """
+        `video_len`: length of the video.
+        `max_video_len`: maximum video tokens allowed in a sequence.
+        `center`: initial starting index.
+        """
+        assert center >= 0 and center < video_len
+        t_clip_len = 0
+        start, end = center, center
+        while (start > 0 or end < video_len) and t_clip_len < max_video_len:
+            # decide the direction to grow.
+            if start <= 0:
+                end += 1
+            elif end >= video_len:
+                start -= 1
+            elif random.random() > 0.5:
+                end += 1
+            else:
+                start -= 1
+            t_clip_len += 1
+        return {"start": [start], "end": [end]}
+
+
+class How2MILNCEAligner(FixedLenAligner):
+    """reference: `antoine77340/MIL-NCE_HowTo100M/video_loader.py`"""
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_candidates = 4
+        self.min_time = 5.0
+        self.num_sec = 3.2
+        # self.num_sec = self.num_frames / float(self.fps)  num_frames=16 / fps = 5
+        # self.num_frames = 16
+
+    def sampling(
+        self,
+        video_id,
+        video_feature,
+        text_feature,
+        centerclip_idx=None,  # will be ignored.
+        sampled_max_text_len=None  # will be ignored.
+    ):
+        text, start, end = self._get_text(text_feature)
+        video = self._get_video(video_feature, start, end)
+
+        vfeats = torch.zeros((self.max_video_len, video_feature.shape[1]))
+        vmasks = torch.zeros((self.max_video_len,), dtype=torch.bool)
+        vfeats[: video.shape[0]] = torch.from_numpy(np.array(video))
+        vmasks[: video.shape[0]] = 1
+
+        caps, cmasks = [], []
+        for words in text:
+            cap, cmask = self._build_text_seq(text_feature, words)
+            caps.append(cap)
+            cmasks.append(cmask)
+        caps = torch.stack(caps)
+        cmasks = torch.stack(cmasks)
+        # video of shape: (video_len)
+        # text of shape (num_candidates, max_text_len)
+
+        return {
+            "caps": caps,
+            "cmasks": cmasks,
+            "vfeats": vfeats,
+            "vmasks": vmasks,
+            # "video_id": video_id,
+        }
+
+    def _get_video(self, video_feature, start, end):
+        start_seek = random.randint(start, int(max(start, end - self.num_sec)))
+        # duration = self.num_sec + 0.1
+        return video_feature[start_seek : int(start_seek + self.num_sec)]
+
+    def _get_text(self, cap):
+        ind = random.randint(0, len(cap["start"]) - 1)
+        if self.num_candidates == 1:
+            words = [ind]
+        else:
+            words = []
+            cap_start = self._find_nearest_candidates(cap, ind)
+            for i in range(self.num_candidates):
+                words.append([max(0, min(len(cap["cap"]) - 1, cap_start + i))])
+
+        start, end = cap["start"][ind], cap["end"][ind]
+        # TODO: May need to be improved for edge cases.
+        # expand the min time.
+        if end - start < self.min_time:
+            diff = self.min_time - end + start
+            start = max(0, start - diff / 2)
+            end = start + self.min_time
+        return words, int(start), int(end)
+
+    def _find_nearest_candidates(self, caption, ind):
+        """find the range of the clips."""
+        start, end = ind, ind
+        #diff = caption["end"][end] - caption["start"][start]
+        n_candidate = 1
+        while n_candidate < self.num_candidates:
+            # the first clip
+            if start == 0:
+                return 0
+            # parentheses added in the following condition to fix a bug.
+            elif end == (len(caption["start"]) - 1):
+                return start - (self.num_candidates - n_candidate)
+            elif (caption["end"][end] - caption["start"][start - 1]) < (
+                caption["end"][end + 1] - caption["start"][start]
+            ):
+                start -= 1
+            else:
+                end += 1
+            n_candidate += 1
+        return start
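A shape sketch for the sample produced above (dimension sizes are hypothetical; names follow the config): the extra candidate dimension on the text side is what the MIL-NCE loss later pools over.

# sample = How2MILNCEAligner(config).sampling(video_id, video_feature, text_feature)
# sample["caps"]   : (num_candidates=4, text_seq_len)   candidate captions
# sample["cmasks"] : (num_candidates=4, text_seq_len)
# sample["vfeats"] : (max_video_len, vfeat_dim)          one video window
# sample["vmasks"] : (max_video_len,)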
+
+
+class PKLJSONStrTextProcessor(TextProcessor):
+    """`caption.json` from howto100m are preprocessed as a
+    dict `[video_id, json_str]`.
+    Json parsing tokenization are conducted on-the-fly and cached into dict.
+    """
+
+    def __init__(self, config, max_clip_text_len=96):
+        print("[Warning] PKLJSONStrTextProcessor is slow for num_workers > 0.")
+        self.caption_pkl_path = str(config.caption_pkl_path)
+        with open(self.caption_pkl_path, "rb") as fd:
+            self.data = pickle.load(fd)
+        self.max_clip_text_len = max_clip_text_len
+        from transformers import AutoTokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            str(config.bert_name), use_fast=config.use_fast
+        )
+
+    def __call__(self, video_id):
+        caption = self.data[video_id]
+        if isinstance(caption, str):
+            import json
+            caption = json.loads(caption)
+            cap = []
+            for clip_idx, text_clip in enumerate(caption["text"]):
+                clip_ids = []
+                if isinstance(text_clip, str):
+                    clip_ids = self.tokenizer(
+                        text_clip[: self.max_clip_text_len],
+                        add_special_tokens=False
+                    )["input_ids"]
+                cap.append(clip_ids)
+            caption["cap"] = cap
+            caption.pop("text")  # save space.
+            self.data[video_id] = caption
+        return caption

+ 100 - 0
examples/MMPT/mmpt/processors/how2retriprocessor.py

@@ -0,0 +1,100 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .how2processor import (
+    ShardedHow2MetaProcessor,
+    ShardedVideoProcessor,
+    ShardedTextProcessor,
+    VariedLenAligner,
+    OverlappedAligner
+)
+
+
+class ShardedHow2VideoRetriMetaProcessor(ShardedHow2MetaProcessor):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_video_per_batch = config.num_video_per_batch
+        self.cands = [
+            self.data[batch_offset:batch_offset + self.num_video_per_batch]
+            for batch_offset in range(
+                0,
+                (len(self.data) // (8 * self.num_video_per_batch))
+                * 8 * self.num_video_per_batch,
+                self.num_video_per_batch,
+            )
+        ]
+
+    def __len__(self):
+        return len(self.cands)
+
+    def set_candidates(self, cands):
+        # no changes on num of batches.
+        print(len(self.cands), "->", len(cands))
+        # assert len(self.cands) == len(cands)
+        self.cands = cands
+
+    def __getitem__(self, idx):
+        video_ids = self.cands[idx]
+        assert isinstance(video_ids, list)
+        sharded_video_idxs = []
+        for video_id in video_ids:
+            shard_id, video_idx = self.video_id_to_shard[video_id]
+            sharded_video_idxs.append((video_id, -1, shard_id, video_idx))
+        return sharded_video_idxs, sharded_video_idxs
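A toy version of the candidate chunking in `__init__` above, assuming 2 videos per retrieval batch; truncating to a multiple of `8 * num_video_per_batch` keeps the candidate lists evenly shardable (presumably across 8 workers/GPUs).

data = list(range(35))                       # stand-in for video ids
num_video_per_batch = 2
limit = (len(data) // (8 * num_video_per_batch)) * 8 * num_video_per_batch   # 32
cands = [
    data[offset:offset + num_video_per_batch]
    for offset in range(0, limit, num_video_per_batch)
]
print(len(cands), cands[0], cands[-1])       # 16 [0, 1] [30, 31]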
+
+
+class ShardedVideoRetriVideoProcessor(ShardedVideoProcessor):
+    """In retrival case the video_id
+    is a list of tuples: `(shard_id, video_idx)` ."""
+
+    def __call__(self, sharded_video_idxs):
+        assert isinstance(sharded_video_idxs, list)
+        cand_feats = []
+        for sharded_video_idx in sharded_video_idxs:
+            feat = super().__call__(sharded_video_idx)
+            cand_feats.append(feat)
+        return cand_feats
+
+
+class ShardedVideoRetriTextProcessor(ShardedTextProcessor):
+    """In retrival case the video_id
+    is a list of tuples: `(shard_id, video_idx)` ."""
+
+    def __call__(self, sharded_video_idxs):
+        assert isinstance(sharded_video_idxs, list)
+        cand_caps = []
+        for sharded_video_idx in sharded_video_idxs:
+            caps = super().__call__(sharded_video_idx)
+            cand_caps.append(caps)
+        return cand_caps
+
+
+class VideoRetriAligner(VariedLenAligner):
+    # Retritask will trim dim-0.
+    def __call__(self, sharded_video_idxs, video_features, text_features):
+        from transformers import default_data_collator
+        batch, video_ids = [], []
+        for video_id, video_feature, text_feature in \
+                zip(sharded_video_idxs, video_features, text_features):
+            sub_batch = super().__call__(video_id, video_feature, text_feature)
+            batch.append(sub_batch)
+            if isinstance(video_id, tuple):
+                video_id = video_id[0]
+            video_ids.append(video_id)
+        batch = default_data_collator(batch)
+        batch["video_id"] = video_ids
+        return batch
+
+
+class VideoRetriOverlappedAligner(OverlappedAligner):
+    # Retritask will trim dim-0.
+    def __call__(self, sharded_video_idxs, video_features, text_features):
+        from transformers import default_data_collator
+        batch, video_ids = [], []
+        for video_id, video_feature, text_feature in \
+                zip(sharded_video_idxs, video_features, text_features):
+            sub_batch = super().__call__(video_id, video_feature, text_feature)
+            batch.append(sub_batch)
+            if isinstance(video_id, tuple):
+                video_id = video_id[0]
+            video_ids.append(video_id)
+        batch = default_data_collator(batch)
+        batch["video_id"] = video_ids
+        return batch

+ 336 - 0
examples/MMPT/mmpt/processors/models/s3dg.py

@@ -0,0 +1,336 @@
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Contains a PyTorch definition for Gated Separable 3D network (S3D-G)
+with a text module for computing joint text-video embedding from raw text
+and video input. The following code will enable you to load the HowTo100M
+pretrained S3D Text-Video model from:
+  A. Miech, J.-B. Alayrac, L. Smaira, I. Laptev, J. Sivic and A. Zisserman,
+  End-to-End Learning of Visual Representations from Uncurated Instructional Videos.
+  https://arxiv.org/abs/1912.06430.
+
+S3D-G was proposed by:
+  S. Xie, C. Sun, J. Huang, Z. Tu and K. Murphy,
+  Rethinking Spatiotemporal Feature Learning For Video Understanding.
+  https://arxiv.org/abs/1712.04851.
+  Tensorflow code: https://github.com/tensorflow/models/blob/master/research/slim/nets/s3dg.py
+
+The S3D architecture was slightly modified with a space to depth trick for TPU
+optimization.
+"""
+
+import torch as th
+import torch.nn.functional as F
+import torch.nn as nn
+import os
+import numpy as np
+import re
+
+
+class InceptionBlock(nn.Module):
+    def __init__(
+        self,
+        input_dim,
+        num_outputs_0_0a,
+        num_outputs_1_0a,
+        num_outputs_1_0b,
+        num_outputs_2_0a,
+        num_outputs_2_0b,
+        num_outputs_3_0b,
+        gating=True,
+    ):
+        super(InceptionBlock, self).__init__()
+        self.conv_b0 = STConv3D(input_dim, num_outputs_0_0a, [1, 1, 1])
+        self.conv_b1_a = STConv3D(input_dim, num_outputs_1_0a, [1, 1, 1])
+        self.conv_b1_b = STConv3D(
+            num_outputs_1_0a, num_outputs_1_0b, [3, 3, 3], padding=1, separable=True
+        )
+        self.conv_b2_a = STConv3D(input_dim, num_outputs_2_0a, [1, 1, 1])
+        self.conv_b2_b = STConv3D(
+            num_outputs_2_0a, num_outputs_2_0b, [3, 3, 3], padding=1, separable=True
+        )
+        self.maxpool_b3 = th.nn.MaxPool3d((3, 3, 3), stride=1, padding=1)
+        self.conv_b3_b = STConv3D(input_dim, num_outputs_3_0b, [1, 1, 1])
+        self.gating = gating
+        self.output_dim = (
+            num_outputs_0_0a + num_outputs_1_0b + num_outputs_2_0b + num_outputs_3_0b
+        )
+        if gating:
+            self.gating_b0 = SelfGating(num_outputs_0_0a)
+            self.gating_b1 = SelfGating(num_outputs_1_0b)
+            self.gating_b2 = SelfGating(num_outputs_2_0b)
+            self.gating_b3 = SelfGating(num_outputs_3_0b)
+
+    def forward(self, input):
+        """Inception block
+      """
+        b0 = self.conv_b0(input)
+        b1 = self.conv_b1_a(input)
+        b1 = self.conv_b1_b(b1)
+        b2 = self.conv_b2_a(input)
+        b2 = self.conv_b2_b(b2)
+        b3 = self.maxpool_b3(input)
+        b3 = self.conv_b3_b(b3)
+        if self.gating:
+            b0 = self.gating_b0(b0)
+            b1 = self.gating_b1(b1)
+            b2 = self.gating_b2(b2)
+            b3 = self.gating_b3(b3)
+        return th.cat((b0, b1, b2, b3), dim=1)
+
+
+class SelfGating(nn.Module):
+    def __init__(self, input_dim):
+        super(SelfGating, self).__init__()
+        self.fc = nn.Linear(input_dim, input_dim)
+
+    def forward(self, input_tensor):
+        """Feature gating as used in S3D-G.
+      """
+        spatiotemporal_average = th.mean(input_tensor, dim=[2, 3, 4])
+        weights = self.fc(spatiotemporal_average)
+        weights = th.sigmoid(weights)
+        return weights[:, :, None, None, None] * input_tensor
+
+
+class STConv3D(nn.Module):
+    def __init__(
+        self, input_dim, output_dim, kernel_size, stride=1, padding=0, separable=False
+    ):
+        super(STConv3D, self).__init__()
+        self.separable = separable
+        self.relu = nn.ReLU(inplace=True)
+        assert len(kernel_size) == 3
+        if separable and kernel_size[0] != 1:
+            spatial_kernel_size = [1, kernel_size[1], kernel_size[2]]
+            temporal_kernel_size = [kernel_size[0], 1, 1]
+            if isinstance(stride, list) and len(stride) == 3:
+                spatial_stride = [1, stride[1], stride[2]]
+                temporal_stride = [stride[0], 1, 1]
+            else:
+                spatial_stride = [1, stride, stride]
+                temporal_stride = [stride, 1, 1]
+            if isinstance(padding, list) and len(padding) == 3:
+                spatial_padding = [0, padding[1], padding[2]]
+                temporal_padding = [padding[0], 0, 0]
+            else:
+                spatial_padding = [0, padding, padding]
+                temporal_padding = [padding, 0, 0]
+        if separable:
+            self.conv1 = nn.Conv3d(
+                input_dim,
+                output_dim,
+                kernel_size=spatial_kernel_size,
+                stride=spatial_stride,
+                padding=spatial_padding,
+                bias=False,
+            )
+            self.bn1 = nn.BatchNorm3d(output_dim)
+            self.conv2 = nn.Conv3d(
+                output_dim,
+                output_dim,
+                kernel_size=temporal_kernel_size,
+                stride=temporal_stride,
+                padding=temporal_padding,
+                bias=False,
+            )
+            self.bn2 = nn.BatchNorm3d(output_dim)
+        else:
+            self.conv1 = nn.Conv3d(
+                input_dim,
+                output_dim,
+                kernel_size=kernel_size,
+                stride=stride,
+                padding=padding,
+                bias=False,
+            )
+            self.bn1 = nn.BatchNorm3d(output_dim)
+
+    def forward(self, input):
+        out = self.relu(self.bn1(self.conv1(input)))
+        if self.separable:
+            out = self.relu(self.bn2(self.conv2(out)))
+        return out
+
+
+class MaxPool3dTFPadding(th.nn.Module):
+    def __init__(self, kernel_size, stride=None, padding="SAME"):
+        super(MaxPool3dTFPadding, self).__init__()
+        if padding == "SAME":
+            padding_shape = self._get_padding_shape(kernel_size, stride)
+            self.padding_shape = padding_shape
+            self.pad = th.nn.ConstantPad3d(padding_shape, 0)
+        self.pool = th.nn.MaxPool3d(kernel_size, stride, ceil_mode=True)
+
+    def _get_padding_shape(self, filter_shape, stride):
+        def _pad_top_bottom(filter_dim, stride_val):
+            pad_along = max(filter_dim - stride_val, 0)
+            pad_top = pad_along // 2
+            pad_bottom = pad_along - pad_top
+            return pad_top, pad_bottom
+
+        padding_shape = []
+        for filter_dim, stride_val in zip(filter_shape, stride):
+            pad_top, pad_bottom = _pad_top_bottom(filter_dim, stride_val)
+            padding_shape.append(pad_top)
+            padding_shape.append(pad_bottom)
+        depth_top = padding_shape.pop(0)
+        depth_bottom = padding_shape.pop(0)
+        padding_shape.append(depth_top)
+        padding_shape.append(depth_bottom)
+        return tuple(padding_shape)
+
+    def forward(self, inp):
+        inp = self.pad(inp)
+        out = self.pool(inp)
+        return out
+
+
+class Sentence_Embedding(nn.Module):
+    def __init__(
+        self,
+        embd_dim,
+        num_embeddings=66250,
+        word_embedding_dim=300,
+        token_to_word_path="dict.npy",
+        max_words=16,
+        output_dim=2048,
+    ):
+        super(Sentence_Embedding, self).__init__()
+        self.word_embd = nn.Embedding(num_embeddings, word_embedding_dim)
+        self.fc1 = nn.Linear(word_embedding_dim, output_dim)
+        self.fc2 = nn.Linear(output_dim, embd_dim)
+        self.word_to_token = {}
+        self.max_words = max_words
+        token_to_word = np.load(token_to_word_path)
+        for i, t in enumerate(token_to_word):
+            self.word_to_token[t] = i + 1
+
+    def _zero_pad_tensor_token(self, tensor, size):
+        if len(tensor) >= size:
+            return tensor[:size]
+        else:
+            zero = th.zeros(size - len(tensor)).long()
+            return th.cat((tensor, zero), dim=0)
+
+    def _split_text(self, sentence):
+        w = re.findall(r"[\w']+", str(sentence))
+        return w
+
+    def _words_to_token(self, words):
+        words = [
+            self.word_to_token[word] for word in words if word in self.word_to_token
+        ]
+        if words:
+            we = self._zero_pad_tensor_token(th.LongTensor(words), self.max_words)
+            return we
+        else:
+            return th.zeros(self.max_words).long()
+
+    def _words_to_ids(self, x):
+        split_x = [self._words_to_token(self._split_text(sent.lower())) for sent in x]
+        return th.stack(split_x, dim=0)
+
+    def forward(self, x):
+        x = self._words_to_ids(x)
+        x = self.word_embd(x)
+        x = F.relu(self.fc1(x))
+        x = th.max(x, dim=1)[0]
+        x = self.fc2(x)
+        return {'text_embedding': x}
+
+
+class S3D(nn.Module):
+    def __init__(self, dict_path, num_classes=512, gating=True, space_to_depth=True):
+        super(S3D, self).__init__()
+        self.num_classes = num_classes
+        self.gating = gating
+        self.space_to_depth = space_to_depth
+        if space_to_depth:
+            self.conv1 = STConv3D(
+                24, 64, [2, 4, 4], stride=1, padding=(1, 2, 2), separable=False
+            )
+        else:
+            self.conv1 = STConv3D(
+                3, 64, [3, 7, 7], stride=2, padding=(1, 3, 3), separable=False
+            )
+        self.conv_2b = STConv3D(64, 64, [1, 1, 1], separable=False)
+        self.conv_2c = STConv3D(64, 192, [3, 3, 3], padding=1, separable=True)
+        self.gating = SelfGating(192)
+        self.maxpool_2a = MaxPool3dTFPadding(
+            kernel_size=(1, 3, 3), stride=(1, 2, 2), padding="SAME"
+        )
+        self.maxpool_3a = MaxPool3dTFPadding(
+            kernel_size=(1, 3, 3), stride=(1, 2, 2), padding="SAME"
+        )
+        self.mixed_3b = InceptionBlock(192, 64, 96, 128, 16, 32, 32)
+        self.mixed_3c = InceptionBlock(
+            self.mixed_3b.output_dim, 128, 128, 192, 32, 96, 64
+        )
+        self.maxpool_4a = MaxPool3dTFPadding(
+            kernel_size=(3, 3, 3), stride=(2, 2, 2), padding="SAME"
+        )
+        self.mixed_4b = InceptionBlock(
+            self.mixed_3c.output_dim, 192, 96, 208, 16, 48, 64
+        )
+        self.mixed_4c = InceptionBlock(
+            self.mixed_4b.output_dim, 160, 112, 224, 24, 64, 64
+        )
+        self.mixed_4d = InceptionBlock(
+            self.mixed_4c.output_dim, 128, 128, 256, 24, 64, 64
+        )
+        self.mixed_4e = InceptionBlock(
+            self.mixed_4d.output_dim, 112, 144, 288, 32, 64, 64
+        )
+        self.mixed_4f = InceptionBlock(
+            self.mixed_4e.output_dim, 256, 160, 320, 32, 128, 128
+        )
+        self.maxpool_5a = self.maxPool3d_5a_2x2 = MaxPool3dTFPadding(
+            kernel_size=(2, 2, 2), stride=(2, 2, 2), padding="SAME"
+        )
+        self.mixed_5b = InceptionBlock(
+            self.mixed_4f.output_dim, 256, 160, 320, 32, 128, 128
+        )
+        self.mixed_5c = InceptionBlock(
+            self.mixed_5b.output_dim, 384, 192, 384, 48, 128, 128
+        )
+        self.fc = nn.Linear(self.mixed_5c.output_dim, num_classes)
+        self.text_module = Sentence_Embedding(num_classes,
+            token_to_word_path=dict_path)
+
+    def _space_to_depth(self, input):
+        """3D space to depth trick for TPU optimization.
+      """
+        B, C, T, H, W = input.shape
+        input = input.view(B, C, T // 2, 2, H // 2, 2, W // 2, 2)
+        input = input.permute(0, 3, 5, 7, 1, 2, 4, 6)
+        input = input.contiguous().view(B, 8 * C, T // 2, H // 2, W // 2)
+        return input
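A toy check of the space-to-depth reshaping above: an RGB clip of shape (B, 3, T, H, W) becomes (B, 24, T//2, H//2, W//2), which is why the space-to-depth stem takes 24 input channels.

import torch as th

x = th.randn(1, 3, 16, 224, 224)
B, C, T, H, W = x.shape
y = x.view(B, C, T // 2, 2, H // 2, 2, W // 2, 2)
y = y.permute(0, 3, 5, 7, 1, 2, 4, 6).contiguous().view(B, 8 * C, T // 2, H // 2, W // 2)
print(y.shape)  # torch.Size([1, 24, 8, 112, 112])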
+
+    def forward(self, inputs):
+        """Defines the S3DG base architecture."""
+        if self.space_to_depth:
+            inputs = self._space_to_depth(inputs)
+        net = self.conv1(inputs)
+        if self.space_to_depth:
+            # we need to replicate 'SAME' tensorflow padding
+            net = net[:, :, 1:, 1:, 1:]
+        net = self.maxpool_2a(net)
+        net = self.conv_2b(net)
+        net = self.conv_2c(net)
+        if self.gating:
+            net = self.gating(net)
+        net = self.maxpool_3a(net)
+        net = self.mixed_3b(net)
+        net = self.mixed_3c(net)
+        net = self.maxpool_4a(net)
+        net = self.mixed_4b(net)
+        net = self.mixed_4c(net)
+        net = self.mixed_4d(net)
+        net = self.mixed_4e(net)
+        net = self.mixed_4f(net)
+        net = self.maxpool_5a(net)
+        net = self.mixed_5b(net)
+        net = self.mixed_5c(net)
+        net = th.mean(net, dim=[2, 3, 4])
+        return {'video_embedding': self.fc(net), 'mixed_5c': net}
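A usage sketch of the joint text-video model above, assuming the released vocabulary and weight files are available locally (the file names `s3d_dict.npy` and `s3d_howto100m.pth` follow the MIL-NCE release; adjust as needed):

import torch as th

model = S3D("s3d_dict.npy", num_classes=512)
model.load_state_dict(th.load("s3d_howto100m.pth"))
model.eval()
with th.no_grad():
    video = th.rand(1, 3, 32, 224, 224)               # B x 3 x T x H x W in [0, 1]
    video_out = model(video)["video_embedding"]       # (1, 512)
    text_out = model.text_module(["mix the flour and water"])["text_embedding"]
similarity = th.matmul(text_out, video_out.t())       # (1, 1) text-video score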

+ 274 - 0
examples/MMPT/mmpt/processors/processor.py

@@ -0,0 +1,274 @@
+# Copyright (c) Facebook, Inc. All Rights Reserved
+
+import numpy as np
+import os
+import torch
+
+
+class Processor(object):
+    """
+    A generic processor for video (codec, feature etc.) and text.
+    """
+
+    def __call__(self, **kwargs):
+        raise NotImplementedError
+
+
+class MetaProcessor(Processor):
+    """
+    A meta processor is expected to load the metadata of a dataset:
+        (e.g., video_ids, or captions).
+    You must implement `__getitem__` (meta datasets are rather diverse).
+    """
+
+    def __init__(self, config):
+        self.split = config.split
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        raise NotImplementedError
+
+    def _get_split_path(self, config):
+        splits = {
+            "train": config.train_path,
+            "valid": config.val_path,
+            "test": config.test_path,
+        }
+        if config.split is not None:
+            return splits[config.split]
+        return config.train_path
+
+
+class TextProcessor(Processor):
+    """
+    A generic text processor (to be renamed `withTokenizer`):
+    tokenizes a string of text on-the-fly.
+    Warning: mostly used for end tasks
+        (on-the-fly tokenization is slow for how2).
+    TODO(huxu): move this class to a subclass.
+    """
+
+    def __init__(self, config):
+        self.bert_name = str(config.bert_name)
+        self.use_fast = config.use_fast
+        from transformers import AutoTokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            self.bert_name, use_fast=self.use_fast
+        )
+
+    def __call__(self, text_id):
+        caption = self.tokenizer(text_id, add_special_tokens=False)
+        return caption["input_ids"]
+
+
+class VideoProcessor(Processor):
+    """
+    A generic video processor: load a numpy video tokens by default.
+    """
+
+    def __init__(self, config):
+        self.vfeat_dir = config.vfeat_dir
+
+    def __call__(self, video_fn):
+        if isinstance(video_fn, tuple):
+            video_fn = video_fn[0]
+        assert isinstance(video_fn, str)
+        video_fn = os.path.join(self.vfeat_dir, video_fn + ".npy")
+        feat = np.load(video_fn)
+        return feat
+
+
+class Aligner(object):
+    """
+    An align processor aligns video and text and outputs a dict of tensors (for a model).
+    """
+    def __init__(self, config):
+        """__init__ needs to be light weight for more workers/threads."""
+        self.split = config.split
+        self.max_video_len = config.max_video_len
+        self.max_len = config.max_len
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(
+            str(config.bert_name), use_fast=config.use_fast
+        )
+        self.cls_token_id = tokenizer.cls_token_id
+        self.sep_token_id = tokenizer.sep_token_id
+        self.pad_token_id = tokenizer.pad_token_id
+        self.mask_token_id = tokenizer.mask_token_id
+
+    def __call__(self, video_id, video_feature, text_feature):
+        raise NotImplementedError
+
+    def _build_video_seq(self, video_feature, video_clips=None):
+        """
+        `video_feature`: available video tokens.
+        `video_clips`: video clip sequence to build.
+        """
+        if not isinstance(video_feature, np.ndarray):
+            raise ValueError(
+                "unsupported type of video_feature", type(video_feature)
+            )
+
+        if video_clips is None:
+            # this is borrowed from DSAligner
+            video_start = 0
+            video_end = min(len(video_feature), self.max_video_len)
+            # the whole sequence is a single clip.
+            video_clips = {"start": [video_start], "end": [video_end]}
+
+        vfeats = np.zeros(
+            (self.max_video_len, video_feature.shape[1]), dtype=np.float32
+        )
+        vmasks = torch.zeros((self.max_video_len,), dtype=torch.bool)
+        video_len = 0
+        for start, end in zip(video_clips["start"], video_clips["end"]):
+            clip_len = min(self.max_video_len - video_len, (end - start))
+            if clip_len > 0:
+                vfeats[video_len: video_len + clip_len] = video_feature[
+                    start: start + clip_len
+                ]
+                vmasks[video_len: video_len + clip_len] = 1
+                video_len += clip_len
+        vfeats = torch.from_numpy(vfeats)
+
+        return vfeats, vmasks
+
+    def _build_text_seq(self, text_feature, text_clip_indexs=None):
+        """
+        `text_feature`: all available clips.
+        `text_clip_indexs`: clip sequence to build.
+        """
+        if text_clip_indexs is None:
+            text_clip_indexs = [0]
+
+        full_caps = []
+        if isinstance(text_feature, dict):
+            for clip_idx in text_clip_indexs:
+                full_caps.extend(text_feature["cap"][clip_idx])
+        else:
+            full_caps = text_feature
+        max_text_len = self.max_len - self.max_video_len - 3
+        full_caps = full_caps[:max_text_len]
+        full_caps = (
+            [self.cls_token_id, self.sep_token_id] + full_caps + [self.sep_token_id]
+        )
+        text_pad_len = self.max_len - len(full_caps) - self.max_video_len
+        padded_full_caps = full_caps + [self.pad_token_id] * text_pad_len
+        caps = torch.LongTensor(padded_full_caps)
+        cmasks = torch.zeros((len(padded_full_caps),), dtype=torch.bool)
+        cmasks[: len(full_caps)] = 1
+
+        return caps, cmasks
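A toy layout of what `_build_text_seq` produces, assuming `max_len=16`, `max_video_len=8`, BERT ids [CLS]=101, [SEP]=102, [PAD]=0, and hypothetical text tokens 7 8 9:

# caps   = [101, 102, 7, 8, 9, 102, 0, 0]    # length max_len - max_video_len = 8
# cmasks = [  1,   1, 1, 1, 1,   1, 0, 0]
# The model later splices the video features in between [CLS] and the first
# [SEP], yielding [CLS] v1 .. v8 [SEP] t1 t2 t3 [SEP] [PAD] [PAD].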
+
+    def batch_post_processing(self, batch, video_feature):
+        return batch
+
+
+class MMAttentionMask2DProcessor(Processor):
+    """text generation requires 2d mask
+    that is harder to generate by GPU at this stage."""
+
+    def __call__(self, vmask, cmask, mtype):
+        if mtype == "textgen":
+            return self._build_textgeneration_mask(vmask, cmask)
+        elif mtype == "videogen":
+            return self._build_videogeneration_mask(vmask, cmask)
+        else:
+            return self._build_mm_mask(vmask, cmask)
+
+    def _build_mm_mask(self, vmask, cmask):
+        mask_1d = torch.cat([cmask[:1], vmask, cmask[1:]], dim=0)
+        return mask_1d[None, :].repeat(mask_1d.size(0), 1)
+
+    def _build_videogeneration_mask(self, vmask, cmask):
+        # cls_mask attends only to the text; attending to video would leak the generation target.
+        cls_text_mask = torch.cat([
+            # [CLS]
+            torch.ones(
+                (1,), dtype=torch.bool, device=cmask.device),
+            # video tokens and [SEP] for video.
+            torch.zeros(
+                (vmask.size(0) + 1,), dtype=torch.bool, device=cmask.device),
+            cmask[2:]
+            ], dim=0)
+
+        # concat horizontally.
+        video_len = int(vmask.sum())
+        video_masks = torch.cat([
+            # [CLS]
+            torch.ones(
+                (video_len, 1), dtype=torch.bool, device=cmask.device
+            ),
+            torch.tril(
+                torch.ones(
+                    (video_len, video_len),
+                    dtype=torch.bool, device=cmask.device)),
+            # video_padding
+            torch.zeros(
+                (video_len, vmask.size(0) - video_len),
+                dtype=torch.bool, device=cmask.device
+            ),
+            # [SEP] for video (unused).
+            torch.zeros(
+                (video_len, 1), dtype=torch.bool, device=cmask.device
+            ),
+            cmask[2:].unsqueeze(0).repeat(video_len, 1)
+            ], dim=1)
+
+        text_masks = cls_text_mask[None, :].repeat(
+            cmask.size(0) - 2, 1)
+        video_padding_masks = cls_text_mask[None, :].repeat(
+            vmask.size(0) - video_len, 1)
+
+        return torch.cat([
+            cls_text_mask[None, :],
+            video_masks,
+            video_padding_masks,
+            torch.cat([cmask[:1], vmask, cmask[1:]], dim=0)[None, :],
+            text_masks
+            ], dim=0)
+
+    def _build_textgeneration_mask(self, vmask, cmask):
+        # cls_mask attends only to the video; attending to text would leak the generation target.
+        cls_video_mask = torch.cat([
+            # [CLS]
+            torch.ones(
+                (1,), dtype=torch.bool, device=cmask.device),
+            vmask,
+            # [SEP]
+            torch.ones((1,), dtype=torch.bool, device=cmask.device),
+            torch.zeros(
+                (cmask.size(0)-2,), dtype=torch.bool, device=cmask.device)
+        ], dim=0)
+
+        # concat horizontally.
+        text_len = int(cmask[2:].sum())
+        text_masks = torch.cat([
+            # [CLS]
+            torch.ones(
+                (text_len, 1), dtype=torch.bool, device=cmask.device
+            ),
+            vmask.unsqueeze(0).repeat(text_len, 1),
+            # [SEP] for video.
+            torch.ones(
+                (text_len, 1), dtype=torch.bool, device=cmask.device
+            ),
+            torch.tril(
+                torch.ones(
+                    (text_len, text_len),
+                    dtype=torch.bool, device=cmask.device)),
+            # padding.
+            torch.zeros(
+                (text_len, cmask.size(0) - text_len - 2),
+                dtype=torch.bool, device=cmask.device
+            )
+        ], dim=1)
+
+        cls_video_masks = cls_video_mask[None, :].repeat(
+            vmask.size(0) + 2, 1)
+        text_padding_masks = cls_video_mask[None, :].repeat(
+            cmask.size(0) - text_len - 2, 1)
+        return torch.cat([
+            cls_video_masks, text_masks, text_padding_masks], dim=0)
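A usage sketch of the 2D mask built above (toy masks; the real ones come from the aligner, and examples/MMPT is assumed to be importable): the output is a square boolean matrix of size `len(vmask) + len(cmask)` in which text positions attend to the video prefix plus a causal (lower-triangular) view of the text.

import torch
from mmpt.processors.processor import MMAttentionMask2DProcessor

vmask = torch.tensor([True, True, True, False])                 # 3 frames + 1 pad
cmask = torch.tensor([True, True, True, True, False, False])    # [CLS] [SEP] t1 t2 pad pad
attn = MMAttentionMask2DProcessor()(vmask, cmask, "textgen")
print(attn.shape)                                               # torch.Size([10, 10])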

+ 22 - 0
examples/MMPT/mmpt/tasks/__init__.py

@@ -0,0 +1,22 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from .task import *
+from .vlmtask import *
+from .retritask import *
+
+try:
+    from .fairseqmmtask import *
+except ImportError:
+    pass
+
+try:
+    from .milncetask import *
+except ImportError:
+    pass
+
+try:
+    from .expretritask import *
+except ImportError:
+    pass

+ 104 - 0
examples/MMPT/mmpt/tasks/fairseqmmtask.py

@@ -0,0 +1,104 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+make a general fairseq task for MM pretraining.
+"""
+
+import random
+
+from fairseq.tasks import LegacyFairseqTask, register_task
+
+from .task import Task
+from .retritask import RetriTask
+from ..datasets import FairseqMMDataset
+from .. import utils
+
+
+@register_task("mmtask")
+class FairseqMMTask(LegacyFairseqTask):
+    @staticmethod
+    def add_args(parser):
+        # Add some command-line arguments for specifying where the data is
+        # located and the maximum supported input length.
+        parser.add_argument(
+            "taskconfig",
+            metavar="FILE",
+            help=("taskconfig to load all configurations" "outside fairseq parser."),
+        )
+
+    @classmethod
+    def setup_task(cls, args, **kwargs):
+        return FairseqMMTask(args)
+
+    def __init__(self, args):
+        super().__init__(args)
+        config = utils.load_config(args)
+        self.mmtask = Task.config_task(config)
+        self.mmtask.build_dataset()
+        self.mmtask.build_model()
+        self.mmtask.build_loss()
+
+    def load_dataset(self, split, **kwargs):
+        split_map = {
+            "train": self.mmtask.train_data,
+            "valid": self.mmtask.val_data,
+            "test": self.mmtask.test_data,
+        }
+        if split not in split_map:
+            raise ValueError("unknown split type.")
+        if split_map[split] is not None:
+            self.datasets[split] = FairseqMMDataset(split_map[split])
+
+    def get_batch_iterator(
+        self,
+        dataset,
+        max_tokens=None,
+        max_sentences=None,
+        max_positions=None,
+        ignore_invalid_inputs=False,
+        required_batch_size_multiple=1,
+        seed=1,
+        num_shards=1,
+        shard_id=0,
+        num_workers=0,
+        epoch=1,
+        data_buffer_size=0,
+        disable_iterator_cache=False,
+        skip_remainder_batch=False,
+        grouped_shuffling=False,
+        update_epoch_batch_itr=False,
+    ):
+        random.seed(epoch)
+        if dataset.mmdataset.split == "train" and isinstance(self.mmtask, RetriTask):
+            if epoch >= self.mmtask.config.retri_epoch:
+                if not hasattr(self.mmtask, "retri_dataloader"):
+                    self.mmtask.build_dataloader()
+                self.mmtask.retrive_candidates(epoch)
+
+        return super().get_batch_iterator(
+            dataset,
+            max_tokens=max_tokens,
+            max_sentences=max_sentences,
+            max_positions=max_positions,
+            ignore_invalid_inputs=ignore_invalid_inputs,
+            required_batch_size_multiple=required_batch_size_multiple,
+            seed=seed,
+            num_shards=num_shards,
+            shard_id=shard_id,
+            num_workers=num_workers,
+            epoch=epoch,
+            data_buffer_size=data_buffer_size,
+            disable_iterator_cache=disable_iterator_cache,
+            skip_remainder_batch=skip_remainder_batch,
+            grouped_shuffling=grouped_shuffling,
+            update_epoch_batch_itr=update_epoch_batch_itr,
+        )
+
+    @property
+    def source_dictionary(self):
+        return None
+
+    @property
+    def target_dictionary(self):
+        return None

+ 27 - 0
examples/MMPT/mmpt/tasks/milncetask.py

@@ -0,0 +1,27 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from .task import Task
+
+
+class MILNCETask(Task):
+    def reshape_subsample(self, sample):
+        if (
+            hasattr(self.config.dataset, "subsampling")
+            and self.config.dataset.subsampling is not None
+            and self.config.dataset.subsampling > 1
+        ):
+            for key in sample:
+                if torch.is_tensor(sample[key]):
+                    tensor = self.flat_subsample(sample[key])
+                    if key in ["caps", "cmasks"]:
+                        size = tensor.size()
+                        batch_size = size[0] * size[1]
+                        expanded_size = (batch_size,) + size[2:]
+                        tensor = tensor.view(expanded_size)
+                    sample[key] = tensor
+        return sample

+ 253 - 0
examples/MMPT/mmpt/tasks/retritask.py

@@ -0,0 +1,253 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import os
+import torch
+import pickle
+import random
+
+from tqdm import tqdm
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+
+from ..processors import (
+    ShardedHow2MetaProcessor,
+    ShardedVideoProcessor,
+    ShardedTextProcessor,
+    VariedLenAligner,
+)
+
+from ..datasets import MMDataset
+from .task import Task
+from ..modules import vectorpool
+from ..evaluators.predictor import Predictor
+from ..utils import set_seed, get_local_rank, get_world_size
+
+
+class RetriTask(Task):
+    """abstract class for task with retrival."""
+
+    def reshape_subsample(self, sample):
+        for key in sample:
+            if torch.is_tensor(sample[key]):
+                sample[key] = self.flat_subsample(sample[key])
+        return sample
+
+    def flat_subsample(self, tensor):
+        if tensor.size(0) == 1:
+            tensor = tensor.squeeze(0)
+        return tensor
+
+    def build_dataloader(self):
+        """called by `get_batch_iterator` in fairseqmmtask. """
+        # TODO: hard-code dataloader for retri for now and configurable in .yaml.
+        # reuse the `train.lst`.
+        self.config.dataset.split = "train"
+        meta_processor = ShardedHow2MetaProcessor(self.config.dataset)
+        video_processor = ShardedVideoProcessor(self.config.dataset)
+        text_processor = ShardedTextProcessor(self.config.dataset)
+
+        aligner = VariedLenAligner(self.config.dataset)
+        aligner.subsampling = self.config.dataset.clip_per_video
+
+        self.retri_data = MMDataset(
+            meta_processor, video_processor, text_processor, aligner
+        )
+
+        retri_sampler = DistributedSampler(self.retri_data)
+        infer_scale = 16
+        batch_size = self.config.dataset.num_video_per_batch \
+            * infer_scale
+
+        self.retri_dataloader = DataLoader(
+            self.retri_data,
+            collate_fn=self.retri_data.collater,
+            batch_size=batch_size,
+            shuffle=False,
+            sampler=retri_sampler,
+            num_workers=self.config.fairseq.dataset.num_workers
+        )
+        return self.retri_dataloader
+
+    def retrive_candidates(self, epoch, dataloader=None):
+        if get_local_rank() == 0:
+            print("running retrieval model.")
+        out_dir = os.path.join(
+            self.config.fairseq.checkpoint.save_dir, "retri")
+        os.makedirs(out_dir, exist_ok=True)
+
+        if not os.path.isfile(
+                os.path.join(
+                    out_dir, "batched_e" + str(epoch) + "_videos0.pkl")
+        ):
+            if dataloader is None:
+                dataloader = self.retri_dataloader
+
+            self.model.eval()
+            self.model.is_train = False
+
+            assert self.retri_data.meta_processor.data == \
+                self.train_data.meta_processor.data  # video_ids not mutated.
+
+            self._retri_predict(epoch, dataloader)
+
+            self.model.train()
+            self.model.is_train = True
+
+        torch.distributed.barrier()
+        output = self._retri_sync(epoch, out_dir)
+        torch.distributed.barrier()
+        self.train_data.meta_processor.set_candidates(output)
+        return output
+
+
+class VideoRetriTask(RetriTask):
+    """RetriTask on video level."""
+
+    def reshape_subsample(self, sample):
+        if (
+            hasattr(self.config.dataset, "clip_per_video")
+            and self.config.dataset.clip_per_video is not None
+            and self.config.dataset.clip_per_video > 1
+        ):
+            for key in sample:
+                if torch.is_tensor(sample[key]):
+                    sample[key] = self.flat_subsample(sample[key])
+        return sample
+
+    def flat_subsample(self, tensor):
+        if tensor.size(0) == 1:
+            tensor = tensor.squeeze(0)
+        return Task.flat_subsample(self, tensor)
+
+    def _retri_predict(self, epoch, dataloader):
+        set_seed(epoch)
+        # save for retrieval.
+        predictor = VideoPredictor(self.config)
+        predictor.predict_loop(
+            self.model, dataloader)
+        set_seed(epoch)  # get the same text clips.
+        # retrieval.
+        retri_predictor = VideoRetriPredictor(
+            self.config)
+        retri_predictor.predict_loop(
+            self.model, predictor.vecpool.retriver, epoch)
+        del predictor
+        del retri_predictor
+
+    def _retri_sync(self, epoch, out_dir):
+        # gpu do the same merge.
+        batched_videos = []
+        for local_rank in range(get_world_size()):
+            fn = os.path.join(
+                out_dir,
+                "batched_e" + str(epoch) + "_videos" + str(local_rank) + ".pkl")
+            with open(fn, "rb") as fr:
+                batched_videos.extend(pickle.load(fr))
+        print(
+            "[INFO] batched_videos",
+            len(batched_videos), len(batched_videos[0]))
+        return batched_videos
+
+
+class VideoPredictor(Predictor):
+    def __init__(self, config):
+        vectorpool_cls = getattr(vectorpool, config.vectorpool_cls)
+        self.vecpool = vectorpool_cls(config)
+
+    def predict_loop(
+        self,
+        model,
+        dataloader,
+        early_stop=-1,
+    ):
+        with torch.no_grad():
+            if get_local_rank() == 0:
+                dataloader = tqdm(dataloader)
+            for batch_idx, batch in enumerate(dataloader):
+                if batch_idx == early_stop:
+                    break
+                self(batch, model)
+        return self.finalize()
+
+    def __call__(self, sample, model, **kwargs):
+        param = next(model.parameters())
+        dtype = param.dtype
+        device = param.device
+        subsample = sample["vfeats"].size(1)
+        sample = self.to_ctx(sample, device, dtype)
+        for key in sample:
+            if torch.is_tensor(sample[key]):
+                size = sample[key].size()
+                if len(size) >= 2:
+                    batch_size = size[0] * size[1]
+                    expanded_size = (
+                        (batch_size,) + size[2:] if len(size) > 2
+                        else (batch_size,)
+                    )
+                    sample[key] = sample[key].view(expanded_size)
+
+        outputs = model(**sample)
+        sample.update(outputs)
+        self.vecpool(sample, subsample)
+
+    def finalize(self):
+        print("[INFO]", self.vecpool)
+        if not self.vecpool.retriver.db.is_trained:
+            self.vecpool.retriver.finalize_training()
+        return self.vecpool.retriver
+
+
+class VideoRetriPredictor(Predictor):
+    """
+    Online Retrieval Predictor for Clips (used by RetriTask).
+    TODO: merge this with VisPredictor?
+    """
+
+    def __init__(self, config):
+        self.pred_dir = os.path.join(
+            config.fairseq.checkpoint.save_dir,
+            "retri")
+        self.num_cands = config.num_cands
+        self.num_video_per_batch = config.dataset.num_video_per_batch
+
+    def predict_loop(
+        self,
+        model,
+        retriver,
+        epoch,
+        early_stop=-1
+    ):
+        # a fake loop that only tries to recover video vectors
+        # from video_ids.
+        batched_videos = []
+        # obtain available video_ids.
+        video_ids = list(retriver.videoid_to_vectoridx.keys())
+
+        dataloader = random.sample(
+            video_ids,
+            len(video_ids) // self.num_video_per_batch
+        )
+
+        if get_local_rank() == 0:
+            dataloader = tqdm(dataloader)
+        for batch_idx, batch in enumerate(dataloader):
+            # batch is one video id.
+            if batch_idx == early_stop:
+                break
+            video_ids = retriver.search_by_video_ids(
+                [batch], self.num_cands)[0]
+            if len(video_ids) > self.num_video_per_batch:
+                # we move the center to make the cluster robust.
+                video_ids = random.sample(video_ids, self.num_video_per_batch)
+            batched_videos.append(video_ids)
+        return self.finalize(batched_videos, epoch)
+
+    def finalize(self, batched_videos, epoch):
+        fn = os.path.join(
+            self.pred_dir,
+            "batched_e" + str(epoch) + "_videos" + str(get_local_rank()) + ".pkl")
+        with open(fn, "wb") as fw:
+            pickle.dump(batched_videos, fw, pickle.HIGHEST_PROTOCOL)
+        return batched_videos

+ 184 - 0
examples/MMPT/mmpt/tasks/task.py

@@ -0,0 +1,184 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import torch
+
+from .. import tasks
+from .. import models
+from .. import losses
+from ..datasets import MMDataset
+from .. import processors
+
+
+class Task(object):
+    """
+    A task refers to one generic training task (e.g., training one model).
+    """
+
+    @classmethod
+    def config_task(cls, config):
+        """
+        Determine whether to load a hard-coded task or to configure a generic one,
+        depending on whether a task string is available in the config.
+        """
+        if config.task is not None:
+            # TODO (huxu): expand the search scope.
+            task_cls = getattr(tasks, config.task)
+            return task_cls(config)
+        else:
+            return Task(config)
+
+    def __init__(self, config):
+        self.config = config
+        self.train_data = None
+        self.val_data = None
+        self.test_data = None
+
+        self.model = None
+        self.loss_fn = None
+        self.eval_fn = None
+
+    def build_dataset(self):
+        """TODO (huxu): move processor breakdown to MMDataset."""
+        """fill-in `self.train_data`, `self.val_data` and `self.test_data`."""
+
+        meta_processor_cls = getattr(
+            processors, self.config.dataset.meta_processor)
+        video_processor_cls = getattr(
+            processors, self.config.dataset.video_processor)
+        text_processor_cls = getattr(
+            processors, self.config.dataset.text_processor)
+        aligner_cls = getattr(
+            processors, self.config.dataset.aligner)
+
+        if self.config.dataset.train_path is not None:
+            self.config.dataset.split = "train"
+            # may be used by meta processor.
+            # meta_processor controls different dataset.
+            meta_processor = meta_processor_cls(self.config.dataset)
+            video_processor = video_processor_cls(self.config.dataset)
+            text_processor = text_processor_cls(self.config.dataset)
+            aligner = aligner_cls(self.config.dataset)
+            self.train_data = MMDataset(
+                meta_processor, video_processor, text_processor, aligner
+            )
+            print("train_len", len(self.train_data))
+            output = self.train_data[0]
+            self.train_data.print_example(output)
+        if self.config.dataset.val_path is not None:
+            self.config.dataset.split = "valid"
+            # may be used by meta processor.
+            meta_processor = meta_processor_cls(self.config.dataset)
+            video_processor = video_processor_cls(self.config.dataset)
+            text_processor = text_processor_cls(self.config.dataset)
+            aligner = aligner_cls(self.config.dataset)
+            self.val_data = MMDataset(
+                meta_processor, video_processor, text_processor, aligner
+            )
+            print("val_len", len(self.val_data))
+            output = self.val_data[0]
+            self.val_data.print_example(output)
+
+        if self.config.dataset.split == "test":
+            # the following is run via launching fairseq-validate.
+            meta_processor = meta_processor_cls(self.config.dataset)
+            video_processor = video_processor_cls(self.config.dataset)
+            text_processor = text_processor_cls(self.config.dataset)
+            aligner = aligner_cls(self.config.dataset)
+
+            self.test_data = MMDataset(
+                meta_processor, video_processor, text_processor, aligner
+            )
+            print("test_len", len(self.test_data))
+            output = self.test_data[0]
+            self.test_data.print_example(output)
+
+    def build_model(self, checkpoint=None):
+        if self.model is None:
+            model_cls = getattr(models, self.config.model.model_cls)
+            self.model = model_cls(self.config)
+        if checkpoint is not None:
+            self.load_checkpoint(checkpoint)
+        return self.model
+
+    def load_checkpoint(self, checkpoint):
+        if self.model is None:
+            raise ValueError("model is not initialized.")
+        state_dict = torch.load(checkpoint)
+        state_dict = self._trim_state_dict(state_dict)
+        self.model.load_state_dict(state_dict, strict=False)
+        # if it's a fp16 model, turn it back.
+        if next(self.model.parameters()).dtype == torch.float16:
+            self.model = self.model.float()
+        return self.model
+
+    def _trim_state_dict(self, state_dict):
+        from collections import OrderedDict
+
+        if "state_dict" in state_dict:
+            state_dict = state_dict["state_dict"]
+        if "model" in state_dict:  # fairseq checkpoint format.
+            state_dict = state_dict["model"]
+        ret_state_dict = OrderedDict()
+        for (
+            key,
+            value,
+        ) in state_dict.items():
+            # remove fairseq wrapper since this is a task.
+            if key.startswith("mmmodel"):
+                key = key[len("mmmodel."):]
+            ret_state_dict[key] = value
+        return ret_state_dict
+
+    def build_loss(self):
+        if self.loss_fn is None and self.config.loss is not None:
+            loss_cls = getattr(losses, self.config.loss.loss_cls)
+            self.loss_fn = loss_cls()
+        return self.loss_fn
+
+    def flat_subsample(self, tensor):
+        size = tensor.size()
+        if len(size) >= 2:
+            batch_size = size[0] * size[1]
+            expanded_size = (
+                (batch_size,) + size[2:] if len(size) > 2
+                else (batch_size,)
+            )
+            tensor = tensor.view(expanded_size)
+        return tensor
+
+    def reshape_subsample(self, sample):
+        if (
+            hasattr(self.config.dataset, "subsampling")
+            and self.config.dataset.subsampling is not None
+            and self.config.dataset.subsampling > 1
+        ):
+            for key in sample:
+                if torch.is_tensor(sample[key]):
+                    sample[key] = self.flat_subsample(sample[key])
+        return sample
+
+    def __call__(self, model, sample):
+        loss = None
+        loss_scalar = float("inf")
+
+        sample = self.reshape_subsample(sample)
+        outputs = self.model(**sample)
+        sample.update(outputs)
+        if self.loss_fn is not None:
+            loss = self.loss_fn(**sample)
+            loss_scalar = loss.item()
+
+        batch_size = sample["caps"].size(0)
+        sample_size = 1
+        return {
+            "loss": loss,
+            "loss_scalar": loss_scalar,
+            "max_len": self.config.dataset.max_len,
+            "batch_size": batch_size,
+            "sample_size": sample_size,
+        }
+
+    def build_dataloader(self):
+        """only used for trainer that lacks building loaders."""
+        raise NotImplementedError

+ 27 - 0
examples/MMPT/mmpt/tasks/vlmtask.py

@@ -0,0 +1,27 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import torch
+
+from .task import Task
+
+
+class VLMTask(Task):
+    """A VLM task for reproducibility.
+    the collator split subsamples into two sub-batches.
+    This has should have no logic changes.
+    but changed the randomness in frame masking.
+    """
+
+    def flat_subsample(self, tensor):
+        size = tensor.size()
+        if len(size) >= 2:
+            batch_size = size[0] * (size[1] // 2)
+            expanded_size = (
+                (batch_size, 2) + size[2:] if len(size) > 2
+                else (batch_size, 2)
+            )
+            tensor = tensor.view(expanded_size)
+            tensor = torch.cat([tensor[:, 0], tensor[:, 1]], dim=0)
+        return tensor

+ 68 - 0
examples/MMPT/mmpt/utils/__init__.py

@@ -0,0 +1,68 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import random
+import numpy as np
+import torch
+
+from .shardedtensor import *
+from .load_config import *
+
+
+def set_seed(seed=43211):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    if torch.backends.cudnn.enabled:
+        torch.backends.cudnn.benchmark = False
+        torch.backends.cudnn.deterministic = True
+
+
+def get_world_size():
+    if torch.distributed.is_initialized():
+        world_size = torch.distributed.get_world_size()
+    else:
+        world_size = 1
+    return world_size
+
+
+def get_local_rank():
+    return torch.distributed.get_rank() \
+        if torch.distributed.is_initialized() else 0
+
+
+def print_on_rank0(func):
+    local_rank = get_local_rank()
+    if local_rank == 0:
+        print("[INFO]", func)
+
+
+class RetriMeter(object):
+    """
+    Statistics on whether retrieval yields a better pair.
+    """
+    def __init__(self, freq=1024):
+        self.freq = freq
+        self.total = 0
+        self.replace = 0
+        self.updates = 0
+
+    def __call__(self, data):
+        if isinstance(data, np.ndarray):
+            self.replace += data.shape[0] - int((data[:, 0] == -1).sum())
+            self.total += data.shape[0]
+        elif torch.is_tensor(data):
+            self.replace += int(data.sum())
+            self.total += data.size(0)
+        else:
+            raise ValueError("unsupported RetriMeter data type.", type(data))
+
+        self.updates += 1
+        if get_local_rank() == 0 and self.updates % self.freq == 0:
+            print("[INFO]", self)
+
+    def __repr__(self):
+        return "RetriMeter (" + str(self.replace / self.total) \
+            + "/" + str(self.replace) + "/" + str(self.total) + ")"

+ 81 - 0
examples/MMPT/mmpt/utils/load_config.py

@@ -0,0 +1,81 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import os
+import omegaconf
+from omegaconf import OmegaConf
+
+
+def load_config(args=None, config_file=None, overwrite_fairseq=False):
+    """TODO (huxu): move fairseq overwrite to another function."""
+    if args is not None:
+        config_file = args.taskconfig
+    config = recursive_config(config_file)
+
+    if config.dataset.subsampling is not None:
+        batch_size = config.fairseq.dataset.batch_size // config.dataset.subsampling
+        print(
+            "adjusting batch_size to {} due to subsampling {}.".format(
+                batch_size, config.dataset.subsampling
+            )
+        )
+        config.fairseq.dataset.batch_size = batch_size
+
+    is_test = config.dataset.split is not None and config.dataset.split == "test"
+    if not is_test:
+        if (
+            config.fairseq.checkpoint is None
+            or config.fairseq.checkpoint.save_dir is None
+        ):
+            raise ValueError("fairseq save_dir or save_path must be specified.")
+
+        save_dir = config.fairseq.checkpoint.save_dir
+        os.makedirs(save_dir, exist_ok=True)
+        if config.fairseq.common.tensorboard_logdir is not None:
+            tb_run_dir = suffix_rundir(
+                save_dir, config.fairseq.common.tensorboard_logdir
+            )
+            config.fairseq.common.tensorboard_logdir = tb_run_dir
+            print(
+                "update tensorboard_logdir as", config.fairseq.common.tensorboard_logdir
+            )
+        os.makedirs(save_dir, exist_ok=True)
+        OmegaConf.save(config=config, f=os.path.join(save_dir, "config.yaml"))
+
+    if overwrite_fairseq and config.fairseq is not None and args is not None:
+        # flatten fields.
+        for group in config.fairseq:
+            for field in config.fairseq[group]:
+                print("overwrite args." + field, "as", config.fairseq[group][field])
+                setattr(args, field, config.fairseq[group][field])
+    return config
+
+
+def recursive_config(config_path):
+    """allows for stacking of configs in any depth."""
+    config = OmegaConf.load(config_path)
+    if config.includes is not None:
+        includes = config.includes
+        config.pop("includes")
+        base_config = recursive_config(includes)
+        config = OmegaConf.merge(base_config, config)
+    return config
+
+
+def suffix_rundir(save_dir, run_dir):
+    max_id = -1
+    for search_dir in os.listdir(save_dir):
+        if search_dir.startswith(run_dir):
+            splits = search_dir.split("_")
+            cur_id = int(splits[1]) if len(splits) > 1 else 0
+            max_id = max(max_id, cur_id)
+    return os.path.join(save_dir, run_dir + "_" + str(max_id + 1))
+
+
+def overwrite_dir(config, replace, basedir):
+    for key in config:
+        if isinstance(config[key], str) and config[key].startswith(basedir):
+            config[key] = config[key].replace(basedir, replace)
+        if isinstance(config[key], omegaconf.dictconfig.DictConfig):
+            overwrite_dir(config[key], replace, basedir)

+ 46 - 0
examples/MMPT/mmpt/utils/shardedtensor.py

@@ -0,0 +1,46 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import os
+import pickle
+import numpy as np
+
+
+class ShardedTensor(object):
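+    """A flat `data` array plus monotonically non-decreasing `starts` offsets;
+    shard `i` is the slice `data[starts[i]:starts[i + 1]]`.
+    """
+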
+    def __init__(self, data, starts):
+        self.data = data
+        self.starts = starts
+        assert self.starts[0] == 0
+        assert self.starts[-1] == len(self.data)
+        assert (self.starts[1:] >= self.starts[:-1]).all()
+        assert (self.starts > -1).all()
+
+    @staticmethod
+    def from_list(xs):
+        starts = np.full((len(xs) + 1,), -1, dtype=np.int64)
+        data = np.concatenate(xs, axis=0)
+        starts[0] = 0
+        for i, x in enumerate(xs):
+            starts[i + 1] = starts[i] + x.shape[0]
+        assert (starts > -1).all()
+        return ShardedTensor(data, starts)
+
+    def __getitem__(self, i):
+        return self.data[self.starts[i] : self.starts[i + 1]]
+
+    def __len__(self):
+        return len(self.starts) - 1
+
+    def lengths(self):
+        return self.starts[1:] - self.starts[:-1]
+
+    def save(self, path):
+        np.save(path + "_starts", self.starts)
+        np.save(path + "_data", self.data)
+
+    @staticmethod
+    def load(path, mmap_mode=None):
+        starts = np.load(path + "_starts.npy", mmap_mode)
+        data = np.load(path + "_data.npy", mmap_mode)
+        return ShardedTensor(data, starts)

+ 117 - 0
examples/MMPT/mmpt_cli/localjob.py

@@ -0,0 +1,117 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import os
+
+from mmpt.utils import recursive_config
+
+
+class BaseJob(object):
+    def __init__(self, yaml_file, dryrun=False):
+        self.yaml_file = yaml_file
+        self.config = recursive_config(yaml_file)
+        self.dryrun = dryrun
+
+    def submit(self, **kwargs):
+        raise NotImplementedError
+
+    def _normalize_cmd(self, cmd_list):
+        cmd_list = list(cmd_list)
+        yaml_index = cmd_list.index("[yaml]")
+        cmd_list[yaml_index] = self.yaml_file
+        return cmd_list
+
+
+class LocalJob(BaseJob):
+
+    CMD_CONFIG = {
+        "local_single": [
+            "fairseq-train", "[yaml]", "--user-dir", "mmpt",
+            "--task", "mmtask", "--arch", "mmarch",
+            "--criterion", "mmloss",
+        ],
+        "local_small": [
+            "fairseq-train", "[yaml]", "--user-dir", "mmpt",
+            "--task", "mmtask", "--arch", "mmarch",
+            "--criterion", "mmloss",
+            "--distributed-world-size", "2"
+        ],
+        "local_big": [
+            "fairseq-train", "[yaml]", "--user-dir", "mmpt",
+            "--task", "mmtask", "--arch", "mmarch",
+            "--criterion", "mmloss",
+            "--distributed-world-size", "8"
+        ],
+        "local_predict": ["python", "mmpt_cli/predict.py", "[yaml]"],
+    }
+
+    def __init__(self, yaml_file, job_type=None, dryrun=False):
+        super().__init__(yaml_file, dryrun)
+        if job_type is None:
+            self.job_type = "local_single"
+            if self.config.task_type is not None:
+                self.job_type = self.config.task_type
+        else:
+            self.job_type = job_type
+        if self.job_type in ["local_single", "local_small"]:
+            if self.config.fairseq.dataset.batch_size > 32:
+                print("decreasing batch_size to 32 for local testing?")
+
+    def submit(self):
+        cmd_list = self._normalize_cmd(LocalJob.CMD_CONFIG[self.job_type])
+        if "predict" not in self.job_type:
+            # append fairseq args.
+            from mmpt.utils import load_config
+
+            config = load_config(config_file=self.yaml_file)
+            for field in config.fairseq:
+                for key in config.fairseq[field]:
+                    if key in ["fp16", "reset_optimizer", "reset_dataloader", "reset_meters"]:  # a list of binary flag.
+                        param = ["--" + key.replace("_", "-")]
+                    else:
+                        if key == "lr":
+                            value = str(config.fairseq[field][key][0])
+                        elif key == "adam_betas":
+                            value = "'"+str(config.fairseq[field][key])+"'"
+                        else:
+                            value = str(config.fairseq[field][key])
+                        param = [
+                            "--" + key.replace("_", "-"),
+                            value
+                        ]
+                    cmd_list.extend(param)
+
+        print("launching", " ".join(cmd_list))
+        if not self.dryrun:
+            os.system(" ".join(cmd_list))
+        return JobStatus("12345678")
+
+
+class JobStatus(object):
+    def __init__(self, job_id):
+        self.job_id = job_id
+
+    def __repr__(self):
+        return self.job_id
+
+    def __str__(self):
+        return self.job_id
+
+    def done(self):
+        return False
+
+    def running(self):
+        return False
+
+    def result(self):
+        if self.done():
+            return "{} is done.".format(self.job_id)
+        else:
+            return "{} is running.".format(self.job_id)
+
+    def stderr(self):
+        return self.result()
+
+    def stdout(self):
+        return self.result()

+ 113 - 0
examples/MMPT/mmpt_cli/predict.py

@@ -0,0 +1,113 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import os
+import glob
+import argparse
+import pprint
+import omegaconf
+
+from omegaconf import OmegaConf
+from torch.utils.data import DataLoader
+
+from mmpt.utils import load_config, set_seed
+from mmpt.evaluators import Evaluator
+from mmpt.evaluators import predictor as predictor_path
+from mmpt.tasks import Task
+from mmpt import processors
+from mmpt.datasets import MMDataset
+
+
+def get_dataloader(config):
+    meta_processor_cls = getattr(processors, config.dataset.meta_processor)
+    video_processor_cls = getattr(processors, config.dataset.video_processor)
+    text_processor_cls = getattr(processors, config.dataset.text_processor)
+    aligner_cls = getattr(processors, config.dataset.aligner)
+
+    meta_processor = meta_processor_cls(config.dataset)
+    video_processor = video_processor_cls(config.dataset)
+    text_processor = text_processor_cls(config.dataset)
+    aligner = aligner_cls(config.dataset)
+
+    test_data = MMDataset(
+        meta_processor,
+        video_processor,
+        text_processor,
+        aligner,
+    )
+    print("test_len", len(test_data))
+    output = test_data[0]
+    test_data.print_example(output)
+
+    test_dataloader = DataLoader(
+        test_data,
+        batch_size=config.fairseq.dataset.batch_size,
+        shuffle=False,
+        num_workers=6,
+        collate_fn=test_data.collater,
+    )
+    return test_dataloader
+
+
+def main(args):
+    config = load_config(args)
+
+    if isinstance(config, omegaconf.dictconfig.DictConfig):
+        print(OmegaConf.to_yaml(config))
+    else:
+        pp = pprint.PrettyPrinter(indent=4)
+        pp.pprint(config)
+
+    mmtask = Task.config_task(config)
+    mmtask.build_model()
+
+    test_dataloader = get_dataloader(config)
+    checkpoint_search_path = os.path.dirname(config.eval.save_path)
+    results = []
+
+    prefix = os.path.basename(args.taskconfig)
+    if prefix.startswith("test"):
+        # loop over all checkpoints for datasets without a validation set.
+        if "best" not in config.fairseq.common_eval.path:
+            print("eval each epoch.")
+            for checkpoint in glob.glob(checkpoint_search_path + "/checkpoint*"):
+                model = mmtask.load_checkpoint(checkpoint)
+                ckpt = os.path.basename(checkpoint)
+                evaluator = Evaluator(config)
+                output = evaluator.evaluate(
+                    model, test_dataloader, ckpt + "_merged")
+                results.append((checkpoint, output))
+        # finally, use the checkpoint specified by the config.
+        model = mmtask.load_checkpoint(config.fairseq.common_eval.path)
+        evaluator = Evaluator(config)
+        output = evaluator.evaluate(model, test_dataloader)
+        results.append((config.fairseq.common_eval.path, output))
+
+        best_result = None
+        best_metric = 0.
+        for checkpoint, result in results:
+            print(checkpoint)
+            evaluator.metric.print_computed_metrics(result)
+            best_score = evaluator.metric.best_metric(result)
+            if best_score > best_metric:
+                best_result = (checkpoint, result)
+                best_metric = best_score
+        print("best results:")
+        print(best_result[0])
+        evaluator.metric.print_computed_metrics(best_result[1])
+
+    elif prefix.startswith("vis"):
+        model = mmtask.load_checkpoint(config.fairseq.common_eval.path)
+        predictor_cls = getattr(predictor_path, config.predictor)
+        predictor = predictor_cls(config)
+        predictor.predict_loop(model, test_dataloader, mmtask, None)
+    else:
+        raise ValueError("unknown prefix of the config file", args.taskconfig)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("taskconfig", type=str)
+    args = parser.parse_args()
+    main(args)

+ 29 - 0
examples/MMPT/pretraining.md

@@ -0,0 +1,29 @@
+# Pretraining
+
+(If you are new to the ideas of `mmpt.processors`, see [README](README.md) first.)
+We mostly use the [howto100M](https://github.com/antoine77340/howto100m) dataset for pretraining (other datasets are coming), so you are unlikely to need a new `MetaProcessor`, `VideoProcessor` or `TextProcessor` and will mostly be working on a new `Aligner`, a new model and a new loss.
+
+### Data Sharding
+Pretraining on Howto100M is heavy on IO since millions of videos and captions sit on the hard disk and cannot fit into memory.
+It is therefore desirable to have an optimized preprocessing step before the actual dataloading.
+
+We support data sharding to pack multiple videos into shards of training data for both videos and captions (see [dataset](DATASET.md) for preprocessing).
+These shards are memory-mapped to reduce the frequency of IO access on millions of files (see the processors whose names start with `Sharded*`).
+This is the default config for the how2 dataset, `projects/task/how2.yaml`.
+
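+As a rough, self-contained sketch of the helper behind these shards (the file name below is made up and `mmpt` is assumed to be importable), `mmpt/utils/shardedtensor.py` packs variable-length per-video arrays into one flat array plus start offsets; reloading with `mmap_mode` then only reads the slices that are actually accessed:
+
+```python
+import numpy as np
+from mmpt.utils import ShardedTensor
+
+# hypothetical per-video S3D feature arrays of different lengths.
+feats = [np.random.rand(n, 512).astype(np.float32) for n in (12, 7, 30)]
+
+shard = ShardedTensor.from_list(feats)  # one flat array + start offsets.
+shard.save("example_shard")             # writes example_shard_data.npy / _starts.npy.
+
+# memory-map the shard so indexing only reads the rows that are accessed.
+shard = ShardedTensor.load("example_shard", mmap_mode="r")
+print(len(shard), shard.lengths(), shard[1].shape)  # 3 [12  7 30] (7, 512)
+```
+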
+Many thanks to Dmytro Okhonko for sharing the code from the MARGE project.
+
+### Training
+Pretraining on Howto100M is expected to run on one or multiple nodes, where each node has 8 GPUs with 32 GB of memory.
+Launching a pretraining with MFM+MLM can be done via:
+```python locallaunch.py projects/mfmmlm/how2.yaml```
+
+### Pre-training with a Retrieval Model (VideoCLIP)
+This project now supports alternating between running a retrieval model and pre-training.
+We implement a basic retrieval model that is built on video hidden states and faiss.
+
+You may need to install faiss via `conda install faiss-cpu -c pytorch`.  
+
+Right now, the hidden state of a video is computed as the average of the pooled visual/text hidden states of its 8 clips.
+See `mmpt/tasks/retritask.py` for more details.
+The `.yaml` config for running pre-training with a retrieval model can be found at `projects/retri/videoretri.yaml`.
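+
+Assuming the launcher usage mirrors the MFM+MLM example above, pre-training with retrieval can presumably be started with something like:
+```python locallaunch.py projects/retri/videoretri.yaml```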

+ 59 - 0
examples/MMPT/projects/mfmmlm.yaml

@@ -0,0 +1,59 @@
+project_dir: mfmmlm
+run_task:
+  - how2.yaml
+  - [vtt.yaml, vttcap.yaml, vttqa.yaml, youcook.yaml, youcookcap.yaml, crosstask.yaml, coin.yaml]
+base_dir: task
+task_group:
+  pretrain:
+    task_list:
+      - how2.yaml
+    dataset:
+      subsampling: 32
+      sampled_min_len: 10
+      sampled_max_len: 64
+      max_video_len: 32
+      max_len: 96
+      aligner: MFMMLMAligner
+      lazy_vfeat_mask: True
+      mfm_probability: 0.15
+      mlm_probability: 0.15
+      mm_prob: 0.5
+    model:
+      model_cls: MMFusionMFMMLM
+      mm_encoder_cls: MMFusionForMFMMLM
+    loss:
+      loss_cls: MFMMLM
+    fairseq:
+      common:
+        fp16: true
+      dataset:
+        batch_size: 256
+      optimization:
+        max_epoch: 15     
+  finetune:
+    task_list:
+      - vtt.yaml
+      - vttqa.yaml
+      - youcook.yaml
+      - youcookcap.yaml
+      - crosstask.yaml
+      - coin.yaml
+    dataset:
+      max_video_len: 32
+      max_len: 96
+    fairseq:
+      common:
+        fp16: true
+    # do not write any model or loss here (they are expected to be fixed in mmfusion).
+  test:
+    task_list:
+      - test_vtt.yaml
+      - test_vttqa.yaml
+      - test_youcook.yaml
+      - test_youcookcap.yaml
+      - test_crosstask.yaml
+      - test_crosstask_zs.yaml
+      - test_coin.yaml
+    dataset:
+      max_video_len: 32
+      max_len: 96

+ 19 - 0
examples/MMPT/projects/mtm/mmfusionmtm.yaml

@@ -0,0 +1,19 @@
+includes: projects/mfmmlm.yaml
+project_dir: mtm/mmfusionmtm
+task_group:
+  pretrain:
+    task: VLMTask  # reproducible
+    dataset:
+      aligner: MFMMLMAligner
+    model:
+      use_seg_emb: True  # reproducible
+      model_cls: MMFusionMTM
+      mm_encoder_cls: MMBertForMFMMLM
+    loss:
+      loss_cls: MTM
+  finetune:
+    model:
+      use_seg_emb: True  # reproducible
+  test:
+    model:
+      use_seg_emb: True  # reproducible

+ 8 - 0
examples/MMPT/projects/mtm/vlm.yaml

@@ -0,0 +1,8 @@
+includes: projects/mtm/mmfusionmtm.yaml
+project_dir: mtm/vlm
+task_group:
+  pretrain:
+    dataset:
+      sampled_min_len: 8
+    loss:
+      loss_cls: MTM

+ 47 - 0
examples/MMPT/projects/mtm/vlm/coin.yaml

@@ -0,0 +1,47 @@
+dataset:
+  video_processor: VideoProcessor
+  bert_name: bert-base-uncased
+  meta_processor: COINActionSegmentationMetaProcessor
+  train_path: data/coin/COIN.json
+  val_path: data/coin/COIN.json
+  vfeat_dir: data/feat/feat_coin_s3d
+  text_processor: COINActionSegmentationTextProcessor
+  aligner: COINActionSegmentationAligner
+  num_iso_layer: 12
+  sliding_window: 8
+  sliding_window_size: 32
+  max_video_len: 32
+  max_len: 96
+fairseq:
+  common:
+    tensorboard_logdir: run
+    log_interval: 1000
+    fp16: true
+  dataset:
+    num_workers: 4
+    batch_size: 1
+  optimization:
+    lr:
+    - 5.0e-05
+    clip_norm: 2.0
+    optimizer: adam
+    adam_betas: (0.9, 0.98)
+    lr_scheduler: polynomial_decay
+    total_num_update: 1000000
+    warmup_updates: 122
+    weight_decay: 0.0
+    ddp_backend: no_c10d
+    max_epoch: 8
+  checkpoint:
+    restore_file: runs/mtm/vlm/checkpoint_best.pt
+    reset_optimizer: true
+    reset_dataloader: true
+    reset_meters: true
+    save_dir: runs/mtm/vlm/coin
+task_type: sweep_big
+model:
+  model_cls: MMFusionActionSegmentation
+  mm_encoder_cls: MMBertForTokenClassification
+  use_seg_emb: true
+loss:
+  loss_cls: CrossEntropy

+ 53 - 0
examples/MMPT/projects/mtm/vlm/crosstask.yaml

@@ -0,0 +1,53 @@
+dataset:
+  video_processor: CrossTaskVideoProcessor
+  bert_name: bert-base-uncased
+  meta_processor: CrossTaskMetaProcessor
+  train_path: data/crosstask/crosstask_release/videos.csv
+  train_csv_path: data/crosstask/crosstask_release/videos.csv
+  val_path: data/crosstask/crosstask_release/videos_val.csv
+  val_csv_path: data/crosstask/crosstask_release/videos_val.csv
+  primary_path: data/crosstask/crosstask_release/tasks_primary.txt
+  related_path: data/crosstask/crosstask_release/tasks_related.txt
+  vfeat_dir: data/feat/feat_crosstask_s3d
+  annotation_path: data/crosstask/crosstask_release/annotations
+  n_train: 30
+  text_processor: CrossTaskTextProcessor
+  aligner: CrossTaskAligner
+  num_iso_layer: 12
+  sliding_window: 16
+  sliding_window_size: 32
+  max_video_len: 32
+  max_len: 96
+fairseq:
+  common:
+    tensorboard_logdir: run
+    log_interval: 1000
+    fp16: true
+  dataset:
+    num_workers: 4
+    batch_size: 1
+  optimization:
+    lr:
+    - 5.0e-05
+    clip_norm: 2.0
+    optimizer: adam
+    adam_betas: (0.9, 0.98)
+    lr_scheduler: polynomial_decay
+    total_num_update: 1000000
+    warmup_updates: 122
+    weight_decay: 0.0
+    ddp_backend: no_c10d
+    max_epoch: 5
+  checkpoint:
+    restore_file: runs/mtm/vlm/checkpoint11.pt
+    reset_optimizer: true
+    reset_dataloader: true
+    reset_meters: true
+    save_dir: runs/mtm/vlm/crosstask
+task_type: sweep_small
+model:
+  model_cls: MMFusionActionLocalization
+  mm_encoder_cls: MMBertForJoint
+  use_seg_emb: true
+loss:
+  loss_cls: BCE

+ 55 - 0
examples/MMPT/projects/mtm/vlm/how2.yaml

@@ -0,0 +1,55 @@
+dataset:
+  video_processor: ShardedVideoProcessor
+  bert_name: bert-base-uncased
+  meta_processor: ShardedHow2MetaProcessor
+  train_path: data/how2/how2_s3d_train.lst
+  val_path: data/how2/how2_s3d_val.lst
+  vfeat_dir: data/feat/feat_how2_s3d_shard_small
+  text_processor: ShardedTextProcessor
+  tfeat_dir: data/feat/feat_how2_s3d_shard_small/raw_caption_dedup.bert-base-uncased.
+  aligner: MFMMLMAligner
+  subsampling: 32
+  sampled_min_len: 8
+  sampled_max_len: 64
+  max_video_len: 32
+  max_len: 96
+  lazy_vfeat_mask: true
+  mfm_probability: 0.15
+  mlm_probability: 0.15
+  mm_prob: 0.5
+fairseq:
+  common:
+    tensorboard_logdir: run
+    log_interval: 1000
+    fp16: true
+  dataset:
+    num_workers: 4
+    batch_size: 256
+  optimization:
+    lr:
+    - 5.0e-05
+    clip_norm: 2.0
+    optimizer: adam
+    adam_betas: (0.9, 0.98)
+    lr_scheduler: polynomial_decay
+    total_num_update: 1000000
+    warmup_updates: 1000
+    weight_decay: 0.0
+    ddp_backend: no_c10d
+    max_epoch: 15
+  checkpoint:
+    save_dir: runs/mtm/vlm
+    save_interval_updates: 1024
+    keep_interval_updates: 2
+    keep_last_epochs: 30
+task_type: sweep_big
+slurm_config: big
+eval:
+  save_path: runs/mtm/vlm
+model:
+  model_cls: MMFusionMTM
+  mm_encoder_cls: MMBertForMFMMLM
+  use_seg_emb: true
+loss:
+  loss_cls: MTM
+task: VLMTask

+ 31 - 0
examples/MMPT/projects/mtm/vlm/test_coin.yaml

@@ -0,0 +1,31 @@
+slurm_config: big
+task_type: local_predict
+dataset:
+  split: test
+  video_processor: VideoProcessor
+  aligner: COINActionSegmentationAligner
+  bert_name: bert-base-uncased
+  test_path: data/coin/COIN.json
+  meta_processor: COINActionSegmentationMetaProcessor
+  vfeat_dir: data/feat/feat_coin_s3d
+  text_processor: COINActionSegmentationTextProcessor
+  num_iso_layer: 12
+  sliding_window: 16
+  sliding_window_size: 32
+  max_video_len: 32
+  max_len: 96
+fairseq:
+  dataset:
+    batch_size: 1
+    valid_subset: test
+    num_workers: 2
+  common_eval:
+    path: runs/mtm/vlm/coin/checkpoint_best.pt
+model:
+  model_cls: MMFusionActionSegmentation
+  mm_encoder_cls: MMBertForTokenClassification
+  use_seg_emb: true
+eval:
+  save_path: runs/mtm/vlm/coin/eval
+metric: COINActionSegmentationMetric
+predictor: COINPredictor

+ 38 - 0
examples/MMPT/projects/mtm/vlm/test_crosstask.yaml

@@ -0,0 +1,38 @@
+slurm_config: big
+task_type: local_predict
+dataset:
+  split: test
+  video_processor: CrossTaskVideoProcessor
+  aligner: CrossTaskAligner
+  bert_name: bert-base-uncased
+  meta_processor: CrossTaskMetaProcessor
+  test_path: data/crosstask/crosstask_release/videos_val.csv
+  train_csv_path: data/crosstask/crosstask_release/videos.csv
+  val_path: data/crosstask/crosstask_release/videos_val.csv
+  val_csv_path: data/crosstask/crosstask_release/videos_val.csv
+  primary_path: data/crosstask/crosstask_release/tasks_primary.txt
+  related_path: data/crosstask/crosstask_release/tasks_related.txt
+  vfeat_dir: data/feat/feat_crosstask_s3d
+  annotation_path: data/crosstask/crosstask_release/annotations
+  n_train: 30
+  text_processor: CrossTaskTextProcessor
+  num_iso_layer: 12
+  sliding_window: 16
+  sliding_window_size: 32
+  max_video_len: 32
+  max_len: 96
+fairseq:
+  dataset:
+    batch_size: 1
+    valid_subset: test
+    num_workers: 2
+  common_eval:
+    path: runs/mtm/vlm/crosstask/checkpoint_best.pt
+model:
+  model_cls: MMFusionActionLocalization
+  mm_encoder_cls: MMBertForJoint
+  use_seg_emb: true
+eval:
+  save_path: runs/mtm/vlm/crosstask/eval
+metric: CrossTaskMetric
+predictor: CrossTaskPredictor

+ 38 - 0
examples/MMPT/projects/mtm/vlm/test_crosstask_zs.yaml

@@ -0,0 +1,38 @@
+slurm_config: big
+task_type: local_predict
+dataset:
+  split: test
+  video_processor: CrossTaskVideoProcessor
+  aligner: CrossTaskAligner
+  bert_name: bert-base-uncased
+  meta_processor: CrossTaskMetaProcessor
+  test_path: data/crosstask/crosstask_release/videos_val.csv
+  train_csv_path: data/crosstask/crosstask_release/videos.csv
+  val_path: data/crosstask/crosstask_release/videos_val.csv
+  val_csv_path: data/crosstask/crosstask_release/videos_val.csv
+  primary_path: data/crosstask/crosstask_release/tasks_primary.txt
+  related_path: data/crosstask/crosstask_release/tasks_related.txt
+  vfeat_dir: data/feat/feat_crosstask_s3d
+  annotation_path: data/crosstask/crosstask_release/annotations
+  n_train: 30
+  text_processor: CrossTaskTextProcessor
+  num_iso_layer: 12
+  sliding_window: 16
+  sliding_window_size: 32
+  max_video_len: 32
+  max_len: 96
+fairseq:
+  dataset:
+    batch_size: 1
+    valid_subset: test
+    num_workers: 2
+  common_eval:
+    path: runs/mtm/vlm/checkpoint_best.pt
+model:
+  model_cls: MMFusionActionLocalization
+  mm_encoder_cls: MMBertForJoint
+  use_seg_emb: true
+eval:
+  save_path: runs/mtm/vlm/crosstask_zs/eval
+metric: CrossTaskMetric
+predictor: CrossTaskPredictor

+ 29 - 0
examples/MMPT/projects/mtm/vlm/test_vtt.yaml

@@ -0,0 +1,29 @@
+slurm_config: big
+task_type: local_predict
+dataset:
+  split: test
+  video_processor: VideoProcessor
+  aligner: DSAligner
+  bert_name: bert-base-uncased
+  meta_processor: MSRVTTMetaProcessor
+  test_path: data/msrvtt/MSRVTT_JSFUSION_test.csv
+  vfeat_dir: data/feat/feat_vtt_s3d
+  text_processor: MSRVTTTextProcessor
+  num_iso_layer: 12
+  max_video_len: 32
+  max_len: 96
+fairseq:
+  dataset:
+    batch_size: 256
+    valid_subset: test
+    num_workers: 2
+  common_eval:
+    path: runs/mtm/vlm/vtt/checkpoint_last.pt
+model:
+  model_cls: MMFusionJoint
+  mm_encoder_cls: MMBertForJoint
+  use_seg_emb: true
+eval:
+  save_path: runs/mtm/vlm/vtt/eval
+metric: RetrievalMetric
+predictor: RetrievalPredictor

+ 29 - 0
examples/MMPT/projects/mtm/vlm/test_vttqa.yaml

@@ -0,0 +1,29 @@
+slurm_config: big
+task_type: local_predict
+dataset:
+  split: test
+  video_processor: VideoProcessor
+  aligner: MSRVTTQAAligner
+  bert_name: bert-base-uncased
+  meta_processor: MSRVTTQAMetaProcessor
+  test_path: data/msrvtt-qa/MSR_MC_test.csv
+  vfeat_dir: data/feat/feat_vtt_s3d
+  text_processor: MSRVTTQATextProcessor
+  num_iso_layer: 12
+  max_video_len: 32
+  max_len: 96
+fairseq:
+  dataset:
+    batch_size: 256
+    valid_subset: test
+    num_workers: 2
+  common_eval:
+    path: runs/mtm/vlm/vttqa/checkpoint_last.pt
+model:
+  model_cls: MMFusionJoint
+  mm_encoder_cls: MMBertForJoint
+  use_seg_emb: true
+eval:
+  save_path: runs/mtm/vlm/vttqa/eval
+metric: QAMetric
+predictor: QAPredictor

+ 31 - 0
examples/MMPT/projects/mtm/vlm/test_youcook.yaml

@@ -0,0 +1,31 @@
+slurm_config: big
+task_type: local_predict
+dataset:
+  split: test
+  video_processor: YoucookVideoProcessor
+  aligner: DSAligner
+  bert_name: bert-base-uncased
+  meta_processor: YoucookMetaProcessor
+  test_path: data/youcook/youcook_val.pkl
+  trainval_annotation: data/youcook/youcookii_annotations_trainval.json
+  use_annotation_text: true
+  vfeat_dir: data/feat/feat_youcook_s3d
+  text_processor: TextProcessor
+  num_iso_layer: 12
+  max_video_len: 32
+  max_len: 96
+fairseq:
+  dataset:
+    batch_size: 256
+    valid_subset: test
+    num_workers: 2
+  common_eval:
+    path: runs/mtm/vlm/youcook/checkpoint_last.pt
+model:
+  model_cls: MMFusionJoint
+  mm_encoder_cls: MMBertForJoint
+  use_seg_emb: true
+eval:
+  save_path: runs/mtm/vlm/youcook/eval
+metric: RetrievalMetric
+predictor: RetrievalPredictor

+ 32 - 0
examples/MMPT/projects/mtm/vlm/test_youcookcap.yaml

@@ -0,0 +1,32 @@
+slurm_config: big
+task_type: local_predict
+dataset:
+  split: test
+  video_processor: YoucookVideoProcessor
+  aligner: DSNLGAligner
+  bert_name: bert-base-uncased
+  meta_processor: YoucookNLGMetaProcessor
+  test_path: data/youcook/val_list.txt
+  trainval_annotation: data/youcook/youcookii_annotations_trainval.json
+  vfeat_dir: data/feat/feat_youcook_s3d
+  text_processor: NLGTextProcessor
+  max_video_len: 32
+  max_len: 96
+fairseq:
+  dataset:
+    batch_size: 256
+    valid_subset: test
+    num_workers: 2
+  common_eval:
+    path: runs/mtm/vlm/youcookcap/checkpoint_best.pt
+model:
+  model_cls: MMFusionNLG
+  mm_encoder_cls: MMBertForNLG
+  max_decode_length: 24
+  use_seg_emb: true
+eval:
+  save_path: runs/mtm/vlm/youcookcap/eval
+metric: NLGMetric
+predictor: NLGPredictor
+gen_param:
+  num_beams: 5

Some files were not shown because too many files changed in this diff